
Text Classification with a Bidirectional LSTM Model

This post walks through the whole pipeline in PyTorch: loading the IMDB sentiment data, wrapping it in a Dataset and DataLoader with padding, a BiLSTM classifier that average-pools the hidden states over the valid time steps, a simple training Runner with periodic evaluation, and finally test-set evaluation and single-sentence prediction.

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from utils.data import load_vocab
from functools import partial
import time
import random
import numpy as np
from nndl.runner import RunnerV3
from nndl.metric import Accuracy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_imdb_data(path):
    assert os.path.exists(path)
    trainset, devset, testset = [], [], []
    with open(os.path.join(path, "train.txt"), "r", encoding='utf-8') as fr:
        for line in fr:
            sentence_label, sentence = line.strip().lower().split("\t", maxsplit=1)
            trainset.append((sentence, sentence_label))
    with open(os.path.join(path, "dev.txt"), "r", encoding='utf-8') as fr:
        for line in fr:
            sentence_label, sentence = line.strip().lower().split("\t", maxsplit=1)
            devset.append((sentence, sentence_label))
    with open(os.path.join(path, "test.txt"), "r", encoding='utf-8') as fr:
        for line in fr:
            sentence_label, sentence = line.strip().lower().split("\t", maxsplit=1)
            testset.append((sentence, sentence_label))
    return trainset, devset, testset

# Load the IMDB dataset
train_data, dev_data, test_data = load_imdb_data("./dataset/")
# Print one sample to check the loaded data format
print(train_data[4])
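Each line of train.txt / dev.txt / test.txt is expected to hold the label, a tab, and the review text (this follows from the split("\t", maxsplit=1) call above). For example, based on the sample printed below, a line looks roughly like:

    0	the premise of an african-american female scrooge in the modern, struggling city was inspired, but ...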
class IMDBDataset(Dataset):
    def __init__(self, examples, word2id_dict):
        super(IMDBDataset, self).__init__()
        # Vocabulary used to map words to integer indices
        self.word2id_dict = word2id_dict
        # Converted dataset
        self.examples = self.words_to_id(examples)

    def words_to_id(self, examples):
        tmp_examples = []
        for idx, example in enumerate(examples):
            seq, label = example
            # Map each word to its vocabulary index; out-of-vocabulary words fall back to the [UNK] id
            seq = [self.word2id_dict.get(word, self.word2id_dict['[UNK]']) for word in seq.split(" ")]
            label = int(label)
            tmp_examples.append([seq, label])
        return tmp_examples

    def __getitem__(self, idx):
        seq, label = self.examples[idx]
        return seq, label

    def __len__(self):
        return len(self.examples)

# Load the vocabulary
word2id_dict = load_vocab("./dataset/vocab.txt")

# Instantiate the datasets
train_set = IMDBDataset(train_data, word2id_dict)
dev_set = IMDBDataset(dev_data, word2id_dict)
test_set = IMDBDataset(test_data, word2id_dict)

print('Number of training samples:', len(train_set))
print('Sample example:', train_set[4])
import os

def load_vocab(path):
    # One token per line; the line index becomes the token id
    assert os.path.exists(path)
    words = []
    with open(path, "r", encoding="utf-8") as f:
        words = f.readlines()
    words = [word.strip() for word in words if word.strip()]
    word2id = dict(zip(words, range(len(words))))
    return word2id
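The embedding layer defined later uses padding_idx=0, and collate_fn pads with word2id_dict["[PAD]"], so vocab.txt is assumed to list [PAD] on its first line. A one-line check (added here for illustration, not part of the original script) makes that assumption explicit:

    assert word2id_dict["[PAD]"] == 0, "vocab.txt should start with [PAD] so that padding_idx=0 matches"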
def collate_fn(batch_data, pad_val=0, max_seq_len=256):
    seqs, seq_lens, labels = [], [], []
    max_len = 0
    for example in batch_data:
        seq, label = example
        # Truncate the sequence to max_seq_len
        seq = seq[:max_seq_len]
        # Keep the truncated sequence
        seqs.append(seq)
        seq_lens.append(len(seq))
        labels.append(label)
        # Track the maximum sequence length in this batch
        max_len = max(max_len, len(seq))
    # Pad every sequence up to the batch maximum length
    for i in range(len(seqs)):
        seqs[i] = seqs[i] + [pad_val] * (max_len - len(seqs[i]))
    # seq_lens stays on the CPU; pack_padded_sequence expects CPU lengths
    return (torch.tensor(seqs).to(device), torch.tensor(seq_lens)), torch.tensor(labels).to(device)

# Quick check of collate_fn on a toy batch
max_seq_len = 5
batch_data = [[[1, 2, 3, 4, 5, 6], 1], [[2, 4, 6], 0]]
(seqs, seq_lens), labels = collate_fn(batch_data, pad_val=word2id_dict["[PAD]"], max_seq_len=max_seq_len)
print("seqs: ", seqs)
print("seq_lens: ", seq_lens)
print("labels: ", labels)

max_seq_len = 256
batch_size = 128
collate_fn = partial(collate_fn, pad_val=word2id_dict["[PAD]"], max_seq_len=max_seq_len)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, drop_last=False, collate_fn=collate_fn)
dev_loader = torch.utils.data.DataLoader(dev_set, batch_size=batch_size,
                                         shuffle=False, drop_last=False, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                          shuffle=False, drop_last=False, collate_fn=collate_fn)
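As a quick sanity check (an illustrative addition, not in the original script), one batch can be pulled from the loader and its shapes inspected; with batch_size = 128 the expected shapes are roughly:

    (batch_ids, batch_lens), batch_labels = next(iter(train_loader))
    print(batch_ids.shape)     # e.g. torch.Size([128, L]) with L <= max_seq_len
    print(batch_lens.shape)    # torch.Size([128])
    print(batch_labels.shape)  # torch.Size([128])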
class AveragePooling(nn.Module):
    def __init__(self):
        super(AveragePooling, self).__init__()

    def forward(self, sequence_output, sequence_length):
        # sequence_length: (batch_size,) tensor of valid lengths
        sequence_length = sequence_length.unsqueeze(-1).to(torch.float32)
        # Build a mask from sequence_length so that padded positions are ignored
        max_len = sequence_output.shape[1]
        device = sequence_output.device
        mask = torch.arange(max_len, device=device) < sequence_length.to(device)
        mask = mask.to(torch.float32).unsqueeze(-1)
        # Zero out the padded positions
        sequence_output = torch.multiply(sequence_output, mask)
        # Average the hidden states over the valid positions only
        batch_mean_hidden = torch.divide(torch.sum(sequence_output, dim=1), sequence_length.to(device))
        return batch_mean_hidden
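A quick way to convince yourself the masking logic in AveragePooling is right is to run it on a tiny hand-made batch. This is only a sanity-check sketch with made-up numbers, not part of the original script:

    pool = AveragePooling()
    # Batch of 2 sequences, max length 3, hidden size 2; zeros are padding
    seq_out = torch.tensor([[[1., 1.], [3., 3.], [0., 0.]],
                            [[2., 4.], [0., 0.], [0., 0.]]])
    lengths = torch.tensor([2, 1])
    print(pool(seq_out, lengths))  # expected: tensor([[2., 2.], [2., 4.]])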
class Model_BiLSTM_FC(nn.Module):
    def __init__(self, num_embeddings, input_size, hidden_size, num_classes=2):
        super(Model_BiLSTM_FC, self).__init__()
        # Vocabulary size
        self.num_embeddings = num_embeddings
        # Dimension of the word vectors
        self.input_size = input_size
        # Number of LSTM hidden units
        self.hidden_size = hidden_size
        # Number of sentiment classes
        self.num_classes = num_classes
        # Embedding layer
        self.embedding_layer = nn.Embedding(num_embeddings, input_size, padding_idx=0)
        # Bidirectional LSTM layer
        self.lstm_layer = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        # Pooling (aggregation) layer
        self.average_layer = AveragePooling()
        # Output layer (2 * hidden_size because the LSTM is bidirectional)
        self.output_layer = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, inputs):
        # Split the model input into token ids and sequence lengths
        input_ids, sequence_length = inputs
        # Look up the word embeddings
        inputs_emb = self.embedding_layer(input_ids)
        # Pack the padded batch so the LSTM skips the padded positions
        packed_input = nn.utils.rnn.pack_padded_sequence(inputs_emb, sequence_length.cpu(), batch_first=True,
                                                         enforce_sorted=False)
        # Run the BiLSTM
        packed_output, _ = self.lstm_layer(packed_input)
        # Unpack back to a padded tensor
        sequence_output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        # Aggregate the per-token outputs into one vector per sequence
        batch_mean_hidden = self.average_layer(sequence_output, sequence_length)
        # Classification logits
        logits = self.output_layer(batch_mean_hidden)
        return logits
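If the pack_padded_sequence / pad_packed_sequence round-trip is unfamiliar, the standalone sketch below (sizes are illustrative only; it reuses the torch / nn imports from the top of the script) shows what the model does internally and why the per-token output has 2 * hidden_size features:

    emb = nn.Embedding(100, 8, padding_idx=0)
    lstm = nn.LSTM(8, 16, batch_first=True, bidirectional=True)
    ids = torch.tensor([[5, 7, 9, 0], [3, 4, 0, 0]])   # padded batch, pad id 0
    lens = torch.tensor([3, 2])                        # true lengths, kept on the CPU
    packed = nn.utils.rnn.pack_padded_sequence(emb(ids), lens, batch_first=True, enforce_sorted=False)
    packed_out, _ = lstm(packed)
    out, out_lens = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
    print(out.shape)   # torch.Size([2, 3, 32]) -> 2 * hidden_size
    print(out_lens)    # tensor([3, 2])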
import torch
import matplotlib.pyplot as plt

class RunnerV3(object):
    def __init__(self, model, optimizer, loss_fn, metric, **kwargs):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.metric = metric  # used only to compute the evaluation metric

        # Evaluation scores recorded during training
        self.dev_scores = []
        # Training losses recorded during training
        self.train_epoch_losses = []  # one loss per epoch
        self.train_step_losses = []   # one loss per step
        self.dev_losses = []
        # Best score seen so far
        self.best_score = 0

    def train(self, train_loader, dev_loader=None, **kwargs):
        # Switch the model to training mode
        self.model.train()

        # Number of training epochs (defaults to 0 if not given)
        num_epochs = kwargs.get("num_epochs", 0)
        # Logging frequency (defaults to 100 steps)
        log_steps = kwargs.get("log_steps", 100)
        # Evaluation frequency
        eval_steps = kwargs.get("eval_steps", 0)
        # Path for saving the best model (defaults to "best_model.pdparams")
        save_path = kwargs.get("save_path", "best_model.pdparams")
        custom_print_log = kwargs.get("custom_print_log", None)

        # Total number of training steps
        num_training_steps = num_epochs * len(train_loader)

        if eval_steps:
            if self.metric is None:
                raise RuntimeError('Error: Metric can not be None!')
            if dev_loader is None:
                raise RuntimeError('Error: dev_loader can not be None!')

        # Number of steps run so far
        global_step = 0
        total_acces = []
        total_losses = []
        Iters = []

        # Train for num_epochs epochs
        for epoch in range(num_epochs):
            # Accumulated training loss for this epoch
            total_loss = 0
            for step, data in enumerate(train_loader):
                X, y = data
                # Forward pass: compute the logits
                logits = self.model(X)

                # Compute the batch accuracy
                probs = torch.softmax(logits, dim=1)
                pred = torch.argmax(probs, dim=1)
                correct = (pred == y).sum().item()
                total = y.size(0)
                acc = correct / total
                total_acces.append(acc)

                loss = self.loss_fn(logits, y)  # mean reduction by default
                total_loss += loss.item()
                total_losses.append(loss.item())
                Iters.append(global_step)

                # Record the per-step training loss
                self.train_step_losses.append((global_step, loss.item()))

                if log_steps and global_step % log_steps == 0:
                    print(
                        f"[Train] epoch: {epoch}/{num_epochs}, step: {global_step}/{num_training_steps}, loss: {loss.item():.5f}")

                # Backpropagate to compute the gradients
                loss.backward()

                if custom_print_log:
                    custom_print_log(self)

                # Mini-batch gradient descent update
                self.optimizer.step()
                # Reset the gradients
                self.optimizer.zero_grad()

                # Decide whether to evaluate at this step
                if eval_steps > 0 and global_step != 0 and \
                        (global_step % eval_steps == 0 or global_step == (num_training_steps - 1)):
                    dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)
                    print(f"[Evaluate] dev score: {dev_score:.5f}, dev loss: {dev_loss:.5f}")
                    # Switch back to training mode after evaluation
                    self.model.train()

                    # If this is the best score so far, save the model
                    if dev_score > self.best_score:
                        self.save_model(save_path)
                        print(
                            f"[Evaluate] best accuracy performence has been updated: {self.best_score:.5f} --> {dev_score:.5f}")
                        self.best_score = dev_score

                global_step += 1

            # Average training loss for this epoch
            trn_loss = total_loss / len(train_loader)
            # Record the epoch-level training loss
            self.train_epoch_losses.append(trn_loss)

        draw_process("trainning acc", "green", Iters, total_acces, "trainning acc")
        print("total_acc:")
        print(total_acces)
        print("total_loss:")
        print(total_losses)
        print("[Train] Training done!")

    # Evaluation: torch.no_grad() disables gradient computation and storage
    @torch.no_grad()
    def evaluate(self, dev_loader, **kwargs):
        assert self.metric is not None
        # Switch the model to evaluation mode
        self.model.eval()

        global_step = kwargs.get("global_step", -1)
        # Accumulated loss on the dev set
        total_loss = 0
        # Reset the metric
        self.metric.reset()

        # Iterate over the dev set batches
        for batch_id, data in enumerate(dev_loader):
            X, y = data
            # Forward pass
            logits = self.model(X)
            # Compute the loss
            loss = self.loss_fn(logits, y).item()
            # Accumulate the loss
            total_loss += loss
            # Accumulate the metric
            self.metric.update(logits, y)

        dev_loss = (total_loss / len(dev_loader))
        self.dev_losses.append((global_step, dev_loss))
        dev_score = self.metric.accumulate()
        self.dev_scores.append(dev_score)
        return dev_score, dev_loss

    # Prediction: torch.no_grad() disables gradient computation and storage
    @torch.no_grad()
    def predict(self, x, **kwargs):
        # Switch the model to evaluation mode
        self.model.eval()
        # Forward pass to get the predictions
        logits = self.model(x)
        return logits

    def save_model(self, save_path):
        torch.save(self.model.state_dict(), save_path)

    def load_model(self, model_path):
        model_state_dict = torch.load(model_path)
        self.model.load_state_dict(model_state_dict)
class Accuracy():
    def __init__(self, is_logist=True):
        # Number of correctly predicted samples
        self.num_correct = 0
        # Total number of samples
        self.num_count = 0
        self.is_logist = is_logist

    def update(self, outputs, labels):
        # shape[1] == 1 means binary classification with a single output; otherwise multi-class
        if outputs.shape[1] == 1:  # binary classification
            outputs = torch.squeeze(outputs, dim=-1)
            if self.is_logist:
                # For logits, predict class 1 when the logit is >= 0
                preds = (outputs >= 0).to(torch.float32)
            else:
                # For probabilities, predict class 1 when the probability is >= 0.5
                preds = (outputs >= 0.5).to(torch.float32)
        else:
            # Multi-class: take the argmax over the class dimension
            preds = torch.argmax(outputs, dim=1)

        # Count the correct predictions in this batch
        labels = torch.squeeze(labels, dim=-1)
        batch_correct = (preds == labels).to(torch.float32).sum().item()
        batch_count = len(labels)

        # Update num_correct and num_count
        self.num_correct += batch_correct
        self.num_count += batch_count

    def accumulate(self):
        # Compute the overall accuracy from the accumulated counts
        if self.num_count == 0:
            return 0
        return self.num_correct / self.num_count

    def reset(self):
        # Reset the counters
        self.num_correct = 0
        self.num_count = 0

    def name(self):
        return "Accuracy"
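As a small illustration (toy numbers, unrelated to the run below), the metric can be exercised on its own:

    acc_metric = Accuracy()
    toy_logits = torch.tensor([[2.0, -1.0], [0.5, 1.5], [3.0, 0.0]])  # argmax: 0, 1, 0
    toy_labels = torch.tensor([0, 1, 1])
    acc_metric.update(toy_logits, toy_labels)
    print(acc_metric.accumulate())  # 2 of 3 correct -> 0.666...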
def draw_process(title, color, iters, data, label):
    plt.title(title, fontsize=24)
    plt.xlabel("iter", fontsize=20)
    plt.ylabel(label, fontsize=20)
    plt.plot(iters, data, color=color, label=label)
    plt.legend()
    plt.grid()
    plt.show()
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

# Number of training epochs
num_epochs = 3
# Learning rate
learning_rate = 0.001
# Number of embeddings = vocabulary size
num_embeddings = len(word2id_dict)
# Dimension of the embedding vectors
input_size = 256
# Dimension of the LSTM hidden state
hidden_size = 256

# Instantiate the model
model = Model_BiLSTM_FC(num_embeddings, input_size, hidden_size).to(device)
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
# Loss function
loss_fn = nn.CrossEntropyLoss()
# Evaluation metric
metric = Accuracy()
# Instantiate the Runner
runner = RunnerV3(model, optimizer, loss_fn, metric)

# Train the model
start_time = time.time()
runner.train(train_loader, dev_loader, num_epochs=num_epochs, eval_steps=10, log_steps=10,
             save_path="./checkpoints/best.pdparams")
end_time = time.time()
print("time: ", (end_time - start_time))
from nndl.tools import plot_training_loss_acc

# Output figure path
fig_name = "./images/6.16.pdf"
# sample_step: sampling interval for the training loss, i.e. plot one point every sample_step steps
# loss_legend_loc: legend position for the loss subplot
# acc_legend_loc: legend position for the accuracy subplot
plot_training_loss_acc(runner, fig_name, fig_size=(16, 6), sample_step=10, loss_legend_loc="lower left",
                       acc_legend_loc="lower right")

# Load the best checkpoint and evaluate on the test set
model_path = "./checkpoints/best.pdparams"
runner.load_model(model_path)
accuracy, _ = runner.evaluate(test_loader)
print(f"Evaluate on test set, Accuracy: {accuracy:.5f}")
id2label = {0: "negative", 1: "positive"}
text = "this movie is so great. I watched it three times already"

# Preprocess the single text
sentence = text.split(" ")
words = [word2id_dict[word] if word in word2id_dict else word2id_dict['[UNK]'] for word in sentence]
words = words[:max_seq_len]
sequence_length = torch.tensor([len(words)], dtype=torch.int64)
words = torch.tensor(words, dtype=torch.int64).unsqueeze(0)

# Run the model to get a prediction
logits = runner.predict((words.to(device), sequence_length.to(device)))
max_label_id = torch.argmax(logits, dim=-1).cpu().numpy()[0]
pred_label = id2label[max_label_id]
print("Label: ", pred_label)
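For reuse, the preprocessing and prediction above can be folded into a small helper. This is only a convenience sketch built on the objects defined earlier (word2id_dict, max_seq_len, runner, device, id2label); the name predict_sentiment is not part of the original script:

    def predict_sentiment(text):
        # Tokenize by whitespace and map to ids, falling back to [UNK]
        ids = [word2id_dict.get(w, word2id_dict['[UNK]']) for w in text.lower().split(" ")][:max_seq_len]
        length = torch.tensor([len(ids)], dtype=torch.int64)
        ids = torch.tensor(ids, dtype=torch.int64).unsqueeze(0)
        logits = runner.predict((ids.to(device), length.to(device)))
        return id2label[torch.argmax(logits, dim=-1).item()]

    print(predict_sentiment("this movie is so great. I watched it three times already"))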

Console output from a full training run:
C:\Users\48163\anaconda3\envs\env_torch\python.exe C:\Users\48163\PycharmProjects\基于双向LSTM模型完成文本分类任务\main.py 
("the premise of an african-american female scrooge in the modern, struggling city was inspired, but nothing else in this film is. here, ms. scrooge is a miserly banker who takes advantage of the employees and customers in the largely poor and black neighborhood it inhabits. there is no doubt about the good intentions of the people involved. part of the problem is that story's roots don't translate well into the urban setting of this film, and the script fails to make the update work. also, the constant message about sharing and giving is repeated so endlessly, the audience becomes tired of it well before the movie reaches its familiar end. this is a message film that doesn't know when to quit. in the title role, the talented cicely tyson gives an overly uptight performance, and at times lines are difficult to understand. the charles dickens novel has been adapted so many times, it's a struggle to adapt it in a way that makes it fresh and relevant, in spite of its very relevant message.", '0')
Number of training samples: 25000
Sample example: ([2, 976, 5, 32, 6860, 618, 7673, 8, 2, 13073, ... (remaining token ids omitted) ...], 0)
seqs:  tensor([[1, 2, 3, 4, 5],
        [2, 4, 6, 0, 0]], device='cuda:0')
seq_lens:  tensor([5, 3])
labels:  tensor([1, 0], device='cuda:0')
[Train] epoch: 0/3, step: 0/588, loss: 0.69439
[Train] epoch: 0/3, step: 10/588, loss: 0.67324

[Evaluate]  dev score: 0.64816, dev loss: 0.67321
[Evaluate] best accuracy performence has been updated: 0.00000 --> 0.64816
[Train] epoch: 0/3, step: 20/588, loss: 0.62388
[Evaluate]  dev score: 0.67872, dev loss: 0.60929
[Evaluate] best accuracy performence has been updated: 0.64816 --> 0.67872
[Train] epoch: 0/3, step: 30/588, loss: 0.57766
[Evaluate]  dev score: 0.66448, dev loss: 0.60312
[Train] epoch: 0/3, step: 40/588, loss: 0.56351
[Evaluate]  dev score: 0.70584, dev loss: 0.58985
[Evaluate] best accuracy performence has been updated: 0.67872 --> 0.70584
[Train] epoch: 0/3, step: 50/588, loss: 0.55316
[Evaluate]  dev score: 0.71792, dev loss: 0.55092
[Evaluate] best accuracy performence has been updated: 0.70584 --> 0.71792
[Train] epoch: 0/3, step: 60/588, loss: 0.53795
[Evaluate]  dev score: 0.73864, dev loss: 0.52103
[Evaluate] best accuracy performence has been updated: 0.71792 --> 0.73864
[Train] epoch: 0/3, step: 70/588, loss: 0.46456
[Evaluate]  dev score: 0.75760, dev loss: 0.50617
[Evaluate] best accuracy performence has been updated: 0.73864 --> 0.75760
[Train] epoch: 0/3, step: 80/588, loss: 0.55581
[Evaluate]  dev score: 0.75344, dev loss: 0.50104
[Train] epoch: 0/3, step: 90/588, loss: 0.37688
[Evaluate]  dev score: 0.76296, dev loss: 0.49015
[Evaluate] best accuracy performence has been updated: 0.75760 --> 0.76296
[Train] epoch: 0/3, step: 100/588, loss: 0.54712
[Evaluate]  dev score: 0.77976, dev loss: 0.47841
[Evaluate] best accuracy performence has been updated: 0.76296 --> 0.77976
[Train] epoch: 0/3, step: 110/588, loss: 0.43581
[Evaluate]  dev score: 0.78328, dev loss: 0.45888
[Evaluate] best accuracy performence has been updated: 0.77976 --> 0.78328
[Train] epoch: 0/3, step: 120/588, loss: 0.47975
[Evaluate]  dev score: 0.80184, dev loss: 0.43699
[Evaluate] best accuracy performence has been updated: 0.78328 --> 0.80184
[Train] epoch: 0/3, step: 130/588, loss: 0.43234
[Evaluate]  dev score: 0.80360, dev loss: 0.43270
[Evaluate] best accuracy performence has been updated: 0.80184 --> 0.80360
[Train] epoch: 0/3, step: 140/588, loss: 0.47868
[Evaluate]  dev score: 0.81344, dev loss: 0.41856
[Evaluate] best accuracy performence has been updated: 0.80360 --> 0.81344
[Train] epoch: 0/3, step: 150/588, loss: 0.50377
[Evaluate]  dev score: 0.81112, dev loss: 0.41700
[Train] epoch: 0/3, step: 160/588, loss: 0.43394
[Evaluate]  dev score: 0.81464, dev loss: 0.41631
[Evaluate] best accuracy performence has been updated: 0.81344 --> 0.81464
[Train] epoch: 0/3, step: 170/588, loss: 0.35849
[Evaluate]  dev score: 0.82304, dev loss: 0.39735
[Evaluate] best accuracy performence has been updated: 0.81464 --> 0.82304
[Train] epoch: 0/3, step: 180/588, loss: 0.42771
[Evaluate]  dev score: 0.82632, dev loss: 0.39682
[Evaluate] best accuracy performence has been updated: 0.82304 --> 0.82632
[Train] epoch: 0/3, step: 190/588, loss: 0.33019
[Evaluate]  dev score: 0.82728, dev loss: 0.38676
[Evaluate] best accuracy performence has been updated: 0.82632 --> 0.82728
[Train] epoch: 1/3, step: 200/588, loss: 0.23055
[Evaluate]  dev score: 0.81368, dev loss: 0.44229
[Train] epoch: 1/3, step: 210/588, loss: 0.35249
[Evaluate]  dev score: 0.77624, dev loss: 0.47421
[Train] epoch: 1/3, step: 220/588, loss: 0.28121
[Evaluate]  dev score: 0.79648, dev loss: 0.47339
[Train] epoch: 1/3, step: 230/588, loss: 0.20005
[Evaluate]  dev score: 0.80088, dev loss: 0.48150
[Train] epoch: 1/3, step: 240/588, loss: 0.23483
[Evaluate]  dev score: 0.82944, dev loss: 0.38803
[Evaluate] best accuracy performence has been updated: 0.82728 --> 0.82944
[Train] epoch: 1/3, step: 250/588, loss: 0.21045
[Evaluate]  dev score: 0.83456, dev loss: 0.39394
[Evaluate] best accuracy performence has been updated: 0.82944 --> 0.83456
[Train] epoch: 1/3, step: 260/588, loss: 0.28375
[Evaluate]  dev score: 0.83160, dev loss: 0.39078
[Train] epoch: 1/3, step: 270/588, loss: 0.21295
[Evaluate]  dev score: 0.83328, dev loss: 0.39633
[Train] epoch: 1/3, step: 280/588, loss: 0.30178
[Evaluate]  dev score: 0.82760, dev loss: 0.39227
[Train] epoch: 1/3, step: 290/588, loss: 0.34948
[Evaluate]  dev score: 0.83824, dev loss: 0.38314
[Evaluate] best accuracy performence has been updated: 0.83456 --> 0.83824
[Train] epoch: 1/3, step: 300/588, loss: 0.35221
[Evaluate]  dev score: 0.83768, dev loss: 0.37990
[Train] epoch: 1/3, step: 310/588, loss: 0.19819
[Evaluate]  dev score: 0.84072, dev loss: 0.38820
[Evaluate] best accuracy performence has been updated: 0.83824 --> 0.84072
[Train] epoch: 1/3, step: 320/588, loss: 0.28313
[Evaluate]  dev score: 0.84000, dev loss: 0.36614
[Train] epoch: 1/3, step: 330/588, loss: 0.23220
[Evaluate]  dev score: 0.84360, dev loss: 0.36771
[Evaluate] best accuracy performence has been updated: 0.84072 --> 0.84360
[Train] epoch: 1/3, step: 340/588, loss: 0.24448
[Evaluate]  dev score: 0.84496, dev loss: 0.36352
[Evaluate] best accuracy performence has been updated: 0.84360 --> 0.84496
[Train] epoch: 1/3, step: 350/588, loss: 0.28082
[Evaluate]  dev score: 0.83952, dev loss: 0.36507
[Train] epoch: 1/3, step: 360/588, loss: 0.22924
[Evaluate]  dev score: 0.84512, dev loss: 0.35552
[Evaluate] best accuracy performence has been updated: 0.84496 --> 0.84512
[Train] epoch: 1/3, step: 370/588, loss: 0.23646
[Evaluate]  dev score: 0.84760, dev loss: 0.36046
[Evaluate] best accuracy performence has been updated: 0.84512 --> 0.84760
[Train] epoch: 1/3, step: 380/588, loss: 0.22302
[Evaluate]  dev score: 0.85168, dev loss: 0.34559
[Evaluate] best accuracy performence has been updated: 0.84760 --> 0.85168
[Train] epoch: 1/3, step: 390/588, loss: 0.29414
[Evaluate]  dev score: 0.83744, dev loss: 0.38188
[Train] epoch: 2/3, step: 400/588, loss: 0.16453
[Evaluate]  dev score: 0.84904, dev loss: 0.45065
[Train] epoch: 2/3, step: 410/588, loss: 0.16047
[Evaluate]  dev score: 0.82440, dev loss: 0.52331
[Train] epoch: 2/3, step: 420/588, loss: 0.14153
[Evaluate]  dev score: 0.84792, dev loss: 0.38296
[Train] epoch: 2/3, step: 430/588, loss: 0.15215
[Evaluate]  dev score: 0.83800, dev loss: 0.47673
[Train] epoch: 2/3, step: 440/588, loss: 0.14156
[Evaluate]  dev score: 0.83712, dev loss: 0.43629
[Train] epoch: 2/3, step: 450/588, loss: 0.09301
[Evaluate]  dev score: 0.84248, dev loss: 0.39954
[Train] epoch: 2/3, step: 460/588, loss: 0.12972
[Evaluate]  dev score: 0.84456, dev loss: 0.45264
[Train] epoch: 2/3, step: 470/588, loss: 0.11309
[Evaluate]  dev score: 0.84696, dev loss: 0.39900
[Train] epoch: 2/3, step: 480/588, loss: 0.15899
[Evaluate]  dev score: 0.83232, dev loss: 0.48246
[Train] epoch: 2/3, step: 490/588, loss: 0.16320
[Evaluate]  dev score: 0.84584, dev loss: 0.40335
[Train] epoch: 2/3, step: 500/588, loss: 0.12657
[Evaluate]  dev score: 0.84744, dev loss: 0.37349
[Train] epoch: 2/3, step: 510/588, loss: 0.07886
[Evaluate]  dev score: 0.84400, dev loss: 0.42932
[Train] epoch: 2/3, step: 520/588, loss: 0.11325
[Evaluate]  dev score: 0.84720, dev loss: 0.41804
[Train] epoch: 2/3, step: 530/588, loss: 0.13167
[Evaluate]  dev score: 0.83648, dev loss: 0.47225
[Train] epoch: 2/3, step: 540/588, loss: 0.06612
[Evaluate]  dev score: 0.84840, dev loss: 0.45871
[Train] epoch: 2/3, step: 550/588, loss: 0.18083
[Evaluate]  dev score: 0.85000, dev loss: 0.44488
[Train] epoch: 2/3, step: 560/588, loss: 0.11420
[Evaluate]  dev score: 0.85088, dev loss: 0.40565
[Train] epoch: 2/3, step: 570/588, loss: 0.10456
[Evaluate]  dev score: 0.83912, dev loss: 0.44197
[Train] epoch: 2/3, step: 580/588, loss: 0.14933
[Evaluate]  dev score: 0.84736, dev loss: 0.41775
[Evaluate]  dev score: 0.85200, dev loss: 0.39351
[Evaluate] best accuracy performence has been updated: 0.85168 --> 0.85200
(a matplotlib window with the per-iteration training-accuracy curve is shown by draw_process here)
total_acc: [per-step training accuracy values, one per iteration, omitted for brevity]
total_loss: [per-step training loss values, one per iteration, omitted for brevity]
[Train] Training done!
time:  1372.6602189540863

Evaluate on test set, Accuracy: 0.84664
Label:  positive

Process finished with exit code 0
 

 

It's getting late; I'll analyze the results tomorrow.

Special thanks: 实践:基于双向LSTM模型完成文本分类任务-CSDN博客

The bug that had been bothering me all day is finally fixed.
