当前位置:   article > 正文

NLP之文本分类_nlp文本分类

nlp文本分类
注意:本实验在百度平台的AI Studio运行

1.什么是文本分类

文本分类就是根据文本内容将文本划分到不同类别,例如新闻系统中,每篇新闻报道会划归到不同的类别。

2.文本分类的应用

  • 内容分类(新闻分类)

  • 邮件过滤(例如垃圾邮件过滤)

  • 用户分类(如商城消费级别、喜好)

  • 评论、文章、对话的情感分类(正面、负面、中性)

3.文本分类案例

  • 任务:建立文本分类模型,并对模型进行训练、评估,从而实现对中文新闻摘要类别正确划分

  • 数据集:从网站上爬取56821条数据中文新闻摘要,包含10种类别,国际、文化、娱乐、体育、财经、汽车、教育、科技、房产、证券,各类别样本数量如下表所示:

  • 模型选择:

  • 步骤:

  • 代码

    【预处理部分】

    1. ########################### 数据预处理 #########################
    2. import os
    3. from multiprocessing import cpu_count
    4. import numpy as np
    5. import paddle
    6. import paddle.fluid as fluid
    7. # 定义一组公共变量
    8. data_root = "data/" # 数据集所在目录
    9. data_file = "news_classify_data.txt" # 原始数据集
    10. train_file = "train.txt" # 训练集文件
    11. test_file = "test.txt" # 测试集文件
    12. dict_file = "dict_txt.txt" # 字典文件(存放字和编码映射关系)
    13. data_file_path = data_root + data_file # 数据集完整路径
    14. train_file_path = data_root + train_file # 训练集文件完整路径
    15. test_file_path = data_root + test_file # 测试集文件完整路径
    16. dict_file_path = data_root + dict_file # 字典文件完整路径
    17. # 取出样本中所有字,对每个字进行编码,将编码结果存入字典文件
    18. def create_dict():
    19. dict_set = set() # 集合,用作去重
    20. with open(data_file_path, "r", encoding="utf-8") as f:
    21. for line in f.readlines(): # 遍历每行
    22. line = line.replace("\n", "") # 去除换行符
    23. tmp_list = line.split("_!_") # 根据分隔符拆分
    24. title = tmp_list[-1] # 最后一个字段即为标题
    25. for word in title: # 取出每个字
    26. dict_set.add(word)
    27. # 遍历集合,取出每个字进行编号
    28. dict_txt = {} # 定义字典
    29. i = 1 # 编码使用的计数器
    30. for word in dict_set:
    31. dict_txt[word] = i # 字-编码 键值对添加到字典
    32. i += 1
    33. dict_txt["<unk>"] = i # 未知字符(在样本中未出现过的字)
    34. # 将字典内容存入文件
    35. with open(dict_file_path, "w", encoding="utf-8") as f:
    36. f.write(str(dict_txt))
    37. print("生成字典结束.")
    38. # 传入一个句子,将每个字替换为编码值,和标签一起返回
    39. def line_encoding(title, dict_txt, label):
    40. new_line = "" # 编码结果
    41. for word in title:
    42. if word in dict_txt: # 在字典中
    43. code = str(dict_txt[word]) # 取出编码值
    44. else: # 不在字典中
    45. code = str(dict_txt["<unk>"]) # 取未知字符编码值
    46. new_line = new_line + code + "," # 追加到字符串后面
    47. new_line = new_line[:-1] # 去掉最后一个多余的逗号
    48. new_line = new_line + "\t" + label + "\n" # 追加标签值
    49. return new_line
    50. # 读取原始样本,取出标题部分进行编码,将编码后的划分测试集/训练集
    51. def create_train_test_file():
    52. # 清空训练集/测试集
    53. with open(train_file_path, "w") as f:
    54. pass
    55. with open(test_file_path, "w") as f:
    56. pass
    57. # 读取字典文件
    58. with open(dict_file_path, "r", encoding="utf-8") as f_dict:
    59. dict_txt = eval(f_dict.readlines()[0]) # 读取字典文件第一行,生成字典对象
    60. # 读取原始样本
    61. with open(data_file_path, "r", encoding="utf-8") as f_data:
    62. lines = f_data.readlines()
    63. i = 0
    64. for line in lines:
    65. tmp_list = line.replace("\n", "").split("_!_") # 拆分
    66. title = tmp_list[3] # 标题
    67. label = tmp_list[1] # 类别
    68. new_line = line_encoding(title, dict_txt, label) # 对标题编码
    69. if i % 10 == 0: # 写入测试集
    70. with open(test_file_path, "a", encoding="utf-8") as f:
    71. f.write(new_line)
    72. else: # 写入训练集
    73. with open(train_file_path, "a", encoding="utf-8") as f:
    74. f.write(new_line)
    75. i += 1
    76. print("生成训练集/测试集结束.")
    77. create_dict() # 根据样本生成字典
    78. create_train_test_file()

    输出:

    1. 生成字典结束.
    2. 生成训练集/测试集结束.

    【模型定义与训练】

    1. paddle.enable_static()
    2. # 读取字典文件,返回字典长度
    3. def get_dict_len(dict_path):
    4. with open(dict_path, "r", encoding="utf-8") as f:
    5. dict_txt = eval(f.readlines()[0])
    6. return len(dict_txt.keys())
    7. def data_mapper(sample):
    8. data, label = sample # 赋值到变量
    9. val = [int(w) for w in data.split(",")] # 将编码值转换位数字(从文件读取为字符串)
    10. return val, int(label)
    11. def train_reader(train_file_path): # 训练集读取器
    12. def reader():
    13. with open(train_file_path, "r") as f:
    14. lines = f.readlines()
    15. np.random.shuffle(lines) # 随机化处理
    16. for line in lines:
    17. data, label = line.split("\t") # 拆分
    18. yield data, label
    19. return paddle.reader.xmap_readers(data_mapper, reader, cpu_count(), 1024)
    20. def test_reader(test_file_path): # 训练集读取器
    21. def reader():
    22. with open(test_file_path, "r") as f:
    23. lines = f.readlines()
    24. for line in lines:
    25. data, label = line.split("\t") # 拆分
    26. yield data, label
    27. return paddle.reader.xmap_readers(data_mapper, reader, cpu_count(), 1024)
    28. # 定义网络
    29. def Text_CNN(data, dict_dim, class_dim=10, emb_dim=128,
    30. hid_dim=128, hid_dim2=128):
    31. """
    32. 定义TextCNN模型
    33. :param data: 输入
    34. :param dict_dim: 词典大小(词语总的数量)
    35. :param class_dim: 分类的数量
    36. :param emb_dim: 词嵌入长度
    37. :param hid_dim: 第一个卷基层卷积核数量
    38. :param hid_dim2: 第二个卷基层卷积核数量
    39. :return: 模型预测结果
    40. """
    41. # embedding层
    42. emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
    43. # 并列两个卷积/池化层
    44. conv1 = fluid.nets.sequence_conv_pool(input=emb, # 输入(词嵌入层输出)
    45. num_filters=hid_dim,# 卷积核数量
    46. filter_size=3,#卷积核大小
    47. act="tanh",#激活函数
    48. pool_type="sqrt")#池化类型
    49. conv2 = fluid.nets.sequence_conv_pool(input=emb, # 输入(词嵌入层输出)
    50. num_filters=hid_dim2,# 卷积核数量
    51. filter_size=4,#卷积核大小
    52. act="tanh",#激活函数
    53. pool_type="sqrt")#池化类型
    54. # fc
    55. output = fluid.layers.fc(input=[conv1, conv2], # 输入
    56. size=class_dim,#输出值个数
    57. act="softmax")#激活函数
    58. return output
    59. # 定义占位符张量
    60. words = fluid.layers.data(name="words",
    61. shape=[1],
    62. dtype="int64",
    63. lod_level=1) # LOD张量用来表示变长数据
    64. label = fluid.layers.data(name="label",
    65. shape=[1],
    66. dtype="int64")
    67. dict_dim = get_dict_len(dict_file_path) # 获取字典长度
    68. # 调用模型函数
    69. model = Text_CNN(words, dict_dim)
    70. # 损失函数
    71. cost = fluid.layers.cross_entropy(input=model, label=label)
    72. avg_cost = fluid.layers.mean(cost)
    73. # 优化器
    74. optimizer = fluid.optimizer.Adam(learning_rate=0.0001)
    75. optimizer.minimize(avg_cost)
    76. # 准确率
    77. accuracy = fluid.layers.accuracy(input=model, label=label)
    78. # 执行器
    79. place = fluid.CUDAPlace(0)
    80. exe = fluid.Executor(place)
    81. exe.run(fluid.default_startup_program())
    82. # reader
    83. ## 训练集reader
    84. tr_reader = train_reader(train_file_path)
    85. batch_train_reader = paddle.batch(tr_reader, batch_size=128)
    86. ## 测试集reader
    87. ts_reader = test_reader(test_file_path)
    88. batch_test_reader = paddle.batch(ts_reader, batch_size=128)
    89. # feeder
    90. feeder = fluid.DataFeeder(place=place, feed_list=[words, label])
    91. # 开始训练
    92. for epoch in range(80): # 外层循环控制训练轮次
    93. for batch_id, data in enumerate(batch_train_reader()): # 内层循环控制批次
    94. train_cost, train_acc = exe.run(fluid.default_main_program(),#program
    95. feed=feeder.feed(data),#喂入的参数
    96. fetch_list=[avg_cost, accuracy])#返回值
    97. if batch_id % 100 == 0:
    98. print("epoch:%d, batch:%d, cost:%f, acc:%f" %
    99. (epoch, batch_id, train_cost[0], train_acc[0]))
    100. # 每轮训练结束后进行模型评估
    101. test_costs_list = [] # 存放测试集损失值
    102. test_accs_list = [] # 存放测试集准确率
    103. for batch_id, data in enumerate(batch_test_reader()):
    104. test_cost, test_acc = exe.run(fluid.default_main_program(),
    105. feed=feeder.feed(data),
    106. fetch_list=[avg_cost, accuracy])
    107. test_costs_list.append(test_cost[0])
    108. test_accs_list.append(test_acc[0])
    109. # 计算所有批次损失值/准确率均值
    110. avg_test_cost = sum(test_costs_list) / len(test_costs_list)
    111. avg_test_acc = sum(test_accs_list) / len(test_accs_list)
    112. print("epoch:%d, test_cost:%f, test_acc:%f" %
    113. (epoch, avg_test_cost, avg_test_acc))
    114. # 训练结束,保存模型
    115. model_save_dir = "model/"
    116. if not os.path.exists(model_save_dir):
    117. os.makedirs(model_save_dir)
    118. fluid.io.save_inference_model(model_save_dir, # 保存路径
    119. feeded_var_names=[words.name],# 使用时传入参数名称
    120. target_vars=[model],#预测结果
    121. executor=exe)#执行器
    122. print("模型保存成功.")

    输出

    1. W0301 17:44:27.392561 134 gpu\_resources.cc:61\] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 12.0, Runtime API Version: 11.2
    2. W0301 17:44:27.396858 134 gpu\_resources.cc:91\] device: 0, cuDNN Version: 8.2.
    3. epoch:0, batch:0, cost:2.311182, acc:0.062500
    4. epoch:0, batch:100, cost:2.063974, acc:0.398438
    5. epoch:0, batch:200, cost:1.564050, acc:0.531250
    6. epoch:0, batch:300, cost:1.113311, acc:0.726562
    7. epoch:0, test\_cost:0.984438, test\_acc:0.699619
    8. epoch:1, batch:0, cost:0.904973, acc:0.703125
    9. epoch:1, batch:100, cost:1.056342, acc:0.648438
    10. epoch:1, batch:200, cost:0.757222, acc:0.773438
    11. epoch:1, batch:300, cost:0.768483, acc:0.757812
    12. epoch:1, test\_cost:0.743878, test\_acc:0.760737
    13. epoch:2, batch:0, cost:0.799179, acc:0.773438
    14. epoch:2, batch:100, cost:0.684252, acc:0.796875
    15. epoch:2, batch:200, cost:0.700367, acc:0.773438
    16. epoch:2, batch:300, cost:0.658518, acc:0.820312
    17. epoch:2, test\_cost:0.662426, test\_acc:0.787224
    18. epoch:3, batch:0, cost:0.732371, acc:0.757812
    19. epoch:3, batch:100, cost:0.621147, acc:0.804688
    20. epoch:3, batch:200, cost:0.592039, acc:0.820312
    21. epoch:3, batch:300, cost:0.566666, acc:0.835938
    22. epoch:3, test\_cost:0.613746, test\_acc:0.799462
    23. epoch:4, batch:0, cost:0.477513, acc:0.835938
    24. epoch:4, batch:100, cost:0.510523, acc:0.820312
    25. epoch:4, batch:200, cost:0.635614, acc:0.765625
    26. epoch:4, batch:300, cost:0.487310, acc:0.820312
    27. epoch:4, test\_cost:0.579772, test\_acc:0.810488
    28. epoch:5, batch:0, cost:0.564993, acc:0.796875
    29. epoch:5, batch:100, cost:0.406841, acc:0.882812
    30. epoch:5, batch:200, cost:0.444392, acc:0.867188
    31. epoch:5, batch:300, cost:0.729928, acc:0.765625
    32. epoch:5, test\_cost:0.549828, test\_acc:0.820210
    33. epoch:6, batch:0, cost:0.676264, acc:0.757812
    34. epoch:6, batch:100, cost:0.596377, acc:0.859375
    35. epoch:6, batch:200, cost:0.532443, acc:0.773438
    36. epoch:6, batch:300, cost:0.534213, acc:0.820312
    37. epoch:6, test\_cost:0.529354, test\_acc:0.824554
    38. epoch:7, batch:0, cost:0.426611, acc:0.867188
    39. epoch:7, batch:100, cost:0.474698, acc:0.851562
    40. epoch:7, batch:200, cost:0.529180, acc:0.851562
    41. epoch:7, batch:300, cost:0.512545, acc:0.820312
    42. epoch:7, test\_cost:0.510936, test\_acc:0.830113
    43. epoch:8, batch:0, cost:0.559594, acc:0.859375
    44. epoch:8, batch:100, cost:0.670438, acc:0.812500
    45. epoch:8, batch:200, cost:0.519166, acc:0.796875
    46. epoch:8, batch:300, cost:0.487624, acc:0.835938
    47. epoch:8, test\_cost:0.495487, test\_acc:0.831584
    48. epoch:9, batch:0, cost:0.523822, acc:0.875000
    49. epoch:9, batch:100, cost:0.604066, acc:0.812500
    50. epoch:9, batch:200, cost:0.420712, acc:0.835938
    51. epoch:9, batch:300, cost:0.490420, acc:0.867188
    52. epoch:9, test\_cost:0.480619, test\_acc:0.838099
    53. epoch:10, batch:0, cost:0.337900, acc:0.875000
    54. epoch:10, batch:100, cost:0.513283, acc:0.820312
    55. epoch:10, batch:200, cost:0.471449, acc:0.835938
    56. epoch:10, batch:300, cost:0.387642, acc:0.859375
    57. epoch:10, test\_cost:0.470769, test\_acc:0.842698
    58. epoch:11, batch:0, cost:0.448336, acc:0.835938
    59. epoch:11, batch:100, cost:0.440521, acc:0.843750
    60. epoch:11, batch:200, cost:0.436573, acc:0.875000
    61. epoch:11, batch:300, cost:0.694043, acc:0.781250
    62. epoch:11, test\_cost:0.460047, test\_acc:0.846606
    63. epoch:12, batch:0, cost:0.481746, acc:0.882812
    64. epoch:12, batch:100, cost:0.552047, acc:0.804688
    65. epoch:12, batch:200, cost:0.386960, acc:0.882812
    66. epoch:12, batch:300, cost:0.341771, acc:0.882812
    67. epoch:12, test\_cost:0.450299, test\_acc:0.850769
    68. epoch:13, batch:0, cost:0.547187, acc:0.859375
    69. epoch:13, batch:100, cost:0.356574, acc:0.875000
    70. epoch:13, batch:200, cost:0.379301, acc:0.867188
    71. epoch:13, batch:300, cost:0.374310, acc:0.898438
    72. epoch:13, test\_cost:0.440953, test\_acc:0.856413
    73. epoch:14, batch:0, cost:0.689324, acc:0.843750
    74. epoch:14, batch:100, cost:0.493557, acc:0.851562
    75. epoch:14, batch:200, cost:0.332372, acc:0.898438
    76. epoch:14, batch:300, cost:0.462879, acc:0.867188
    77. epoch:14, test\_cost:0.434527, test\_acc:0.858323
    78. epoch:15, batch:0, cost:0.539430, acc:0.859375
    79. epoch:15, batch:100, cost:0.610829, acc:0.828125
    80. epoch:15, batch:200, cost:0.567205, acc:0.773438
    81. epoch:15, batch:300, cost:0.448473, acc:0.851562
    82. epoch:15, test\_cost:0.429657, test\_acc:0.858929
    83. epoch:16, batch:0, cost:0.275449, acc:0.914062
    84. epoch:16, batch:100, cost:0.322028, acc:0.890625
    85. epoch:16, batch:200, cost:0.433238, acc:0.843750
    86. epoch:16, batch:300, cost:0.435924, acc:0.875000
    87. epoch:16, test\_cost:0.420777, test\_acc:0.864314
    88. epoch:17, batch:0, cost:0.463425, acc:0.859375
    89. epoch:17, batch:100, cost:0.498210, acc:0.890625
    90. epoch:17, batch:200, cost:0.347210, acc:0.898438
    91. epoch:17, batch:300, cost:0.375353, acc:0.867188
    92. epoch:17, test\_cost:0.414920, test\_acc:0.866483
    93. epoch:18, batch:0, cost:0.371144, acc:0.906250
    94. epoch:18, batch:100, cost:0.511294, acc:0.828125
    95. epoch:18, batch:200, cost:0.431728, acc:0.828125
    96. epoch:18, batch:300, cost:0.505222, acc:0.843750
    97. epoch:18, test\_cost:0.412018, test\_acc:0.866136
    98. epoch:19, batch:0, cost:0.417319, acc:0.859375
    99. epoch:19, batch:100, cost:0.405875, acc:0.867188
    100. epoch:19, batch:200, cost:0.466319, acc:0.843750
    101. epoch:19, batch:300, cost:0.524598, acc:0.820312
    102. epoch:19, test\_cost:0.408254, test\_acc:0.870387
    103. epoch:20, batch:0, cost:0.278774, acc:0.921875
    104. epoch:20, batch:100, cost:0.375402, acc:0.875000
    105. epoch:20, batch:200, cost:0.512493, acc:0.851562
    106. epoch:20, batch:300, cost:0.352869, acc:0.867188
    107. epoch:20, test\_cost:0.402862, test\_acc:0.870646
    108. epoch:21, batch:0, cost:0.328388, acc:0.890625
    109. epoch:21, batch:100, cost:0.474930, acc:0.843750
    110. epoch:21, batch:200, cost:0.279459, acc:0.898438
    111. epoch:21, batch:300, cost:0.480916, acc:0.843750
    112. epoch:21, test\_cost:0.398193, test\_acc:0.870476
    113. epoch:22, batch:0, cost:0.360476, acc:0.914062
    114. epoch:22, batch:100, cost:0.399123, acc:0.867188
    115. epoch:22, batch:200, cost:0.330940, acc:0.898438
    116. epoch:22, batch:300, cost:0.449070, acc:0.851562
    117. epoch:22, test\_cost:0.396272, test\_acc:0.872644
    118. epoch:23, batch:0, cost:0.311765, acc:0.882812
    119. epoch:23, batch:100, cost:0.430598, acc:0.859375
    120. epoch:23, batch:200, cost:0.371466, acc:0.867188
    121. epoch:23, batch:300, cost:0.497460, acc:0.859375
    122. epoch:23, test\_cost:0.391935, test\_acc:0.874990
    123. epoch:24, batch:0, cost:0.278461, acc:0.921875
    124. epoch:24, batch:100, cost:0.384332, acc:0.867188
    125. epoch:24, batch:200, cost:0.687089, acc:0.804688
    126. epoch:24, batch:300, cost:0.465835, acc:0.835938
    127. epoch:24, test\_cost:0.386384, test\_acc:0.874905
    128. epoch:25, batch:0, cost:0.359800, acc:0.914062
    129. epoch:25, batch:100, cost:0.370942, acc:0.906250
    130. epoch:25, batch:200, cost:0.343612, acc:0.906250
    131. epoch:25, batch:300, cost:0.373149, acc:0.859375
    132. epoch:25, test\_cost:0.385754, test\_acc:0.875249
    133. epoch:26, batch:0, cost:0.359912, acc:0.859375
    134. epoch:26, batch:100, cost:0.299233, acc:0.906250
    135. epoch:26, batch:200, cost:0.321898, acc:0.882812
    136. epoch:26, batch:300, cost:0.506139, acc:0.820312
    137. epoch:26, test\_cost:0.382092, test\_acc:0.877597
    138. epoch:27, batch:0, cost:0.438806, acc:0.882812
    139. epoch:27, batch:100, cost:0.351698, acc:0.867188
    140. epoch:27, batch:200, cost:0.413263, acc:0.875000
    141. epoch:27, batch:300, cost:0.327677, acc:0.875000
    142. epoch:27, test\_cost:0.379122, test\_acc:0.880460
    143. epoch:28, batch:0, cost:0.329184, acc:0.921875
    144. epoch:28, batch:100, cost:0.489258, acc:0.882812
    145. epoch:28, batch:200, cost:0.375317, acc:0.890625
    146. epoch:28, batch:300, cost:0.355702, acc:0.859375
    147. epoch:28, test\_cost:0.377964, test\_acc:0.881066
    148. epoch:29, batch:0, cost:0.360147, acc:0.882812
    149. epoch:29, batch:100, cost:0.361545, acc:0.906250
    150. epoch:29, batch:200, cost:0.535644, acc:0.812500
    151. epoch:29, batch:300, cost:0.463827, acc:0.789062
    152. epoch:29, test\_cost:0.374992, test\_acc:0.879156
    153. epoch:30, batch:0, cost:0.386321, acc:0.843750
    154. epoch:30, batch:100, cost:0.450116, acc:0.851562
    155. epoch:30, batch:200, cost:0.380319, acc:0.867188
    156. epoch:30, batch:300, cost:0.357393, acc:0.914062
    157. epoch:30, test\_cost:0.372232, test\_acc:0.880198
    158. epoch:31, batch:0, cost:0.338851, acc:0.882812
    159. epoch:31, batch:100, cost:0.418707, acc:0.890625
    160. epoch:31, batch:200, cost:0.349568, acc:0.875000
    161. epoch:31, batch:300, cost:0.414638, acc:0.882812
    162. epoch:31, test\_cost:0.373127, test\_acc:0.879245
    163. epoch:32, batch:0, cost:0.278832, acc:0.906250
    164. epoch:32, batch:100, cost:0.538143, acc:0.851562
    165. epoch:32, batch:200, cost:0.418359, acc:0.890625
    166. epoch:32, batch:300, cost:0.510367, acc:0.875000
    167. epoch:32, test\_cost:0.370239, test\_acc:0.880896
    168. epoch:33, batch:0, cost:0.410598, acc:0.835938
    169. epoch:33, batch:100, cost:0.295002, acc:0.906250
    170. epoch:33, batch:200, cost:0.430560, acc:0.828125
    171. epoch:33, batch:300, cost:0.417476, acc:0.859375
    172. epoch:33, test\_cost:0.367410, test\_acc:0.881155
    173. epoch:34, batch:0, cost:0.337740, acc:0.937500
    174. epoch:34, batch:100, cost:0.304080, acc:0.906250
    175. epoch:34, batch:200, cost:0.359049, acc:0.890625
    176. epoch:34, batch:300, cost:0.373999, acc:0.890625
    177. epoch:34, test\_cost:0.367002, test\_acc:0.880113
    178. epoch:35, batch:0, cost:0.411581, acc:0.898438
    179. epoch:35, batch:100, cost:0.400797, acc:0.851562
    180. epoch:35, batch:200, cost:0.482271, acc:0.828125
    181. epoch:35, batch:300, cost:0.340450, acc:0.890625
    182. epoch:35, test\_cost:0.363663, test\_acc:0.883068
    183. epoch:36, batch:0, cost:0.338912, acc:0.875000
    184. epoch:36, batch:100, cost:0.416916, acc:0.867188
    185. epoch:36, batch:200, cost:0.313621, acc:0.882812
    186. epoch:36, batch:300, cost:0.677497, acc:0.796875
    187. epoch:36, test\_cost:0.361819, test\_acc:0.882983
    188. epoch:37, batch:0, cost:0.329249, acc:0.867188
    189. epoch:37, batch:100, cost:0.375915, acc:0.890625
    190. epoch:37, batch:200, cost:0.290267, acc:0.906250
    191. epoch:37, batch:300, cost:0.388264, acc:0.859375
    192. epoch:37, test\_cost:0.363713, test\_acc:0.880025
    193. epoch:38, batch:0, cost:0.452093, acc:0.875000
    194. epoch:38, batch:100, cost:0.237014, acc:0.898438
    195. epoch:38, batch:200, cost:0.334976, acc:0.898438
    196. epoch:38, batch:300, cost:0.386618, acc:0.875000
    197. epoch:38, test\_cost:0.357681, test\_acc:0.884889
    198. epoch:39, batch:0, cost:0.397014, acc:0.867188
    199. epoch:39, batch:100, cost:0.387132, acc:0.882812
    200. epoch:39, batch:200, cost:0.262646, acc:0.921875
    201. epoch:39, batch:300, cost:0.295718, acc:0.906250
    202. epoch:39, test\_cost:0.358814, test\_acc:0.884542
    203. epoch:40, batch:0, cost:0.336061, acc:0.875000
    204. epoch:40, batch:100, cost:0.393282, acc:0.867188
    205. epoch:40, batch:200, cost:0.453071, acc:0.867188
    206. epoch:40, batch:300, cost:0.276213, acc:0.921875
    207. epoch:40, test\_cost:0.355846, test\_acc:0.886278
    208. epoch:41, batch:0, cost:0.362588, acc:0.867188
    209. epoch:41, batch:100, cost:0.293396, acc:0.914062
    210. epoch:41, batch:200, cost:0.351766, acc:0.890625
    211. epoch:41, batch:300, cost:0.437711, acc:0.820312
    212. epoch:41, test\_cost:0.356017, test\_acc:0.886799
    213. epoch:42, batch:0, cost:0.431722, acc:0.843750
    214. epoch:42, batch:100, cost:0.296809, acc:0.914062
    215. epoch:42, batch:200, cost:0.300333, acc:0.898438
    216. epoch:42, batch:300, cost:0.392034, acc:0.859375
    217. epoch:42, test\_cost:0.354504, test\_acc:0.885580
    218. epoch:43, batch:0, cost:0.237395, acc:0.945312
    219. epoch:43, batch:100, cost:0.274653, acc:0.914062
    220. epoch:43, batch:200, cost:0.320165, acc:0.898438
    221. epoch:43, batch:300, cost:0.233366, acc:0.937500
    222. epoch:43, test\_cost:0.352862, test\_acc:0.885410
    223. epoch:44, batch:0, cost:0.309431, acc:0.953125
    224. epoch:44, batch:100, cost:0.371803, acc:0.843750
    225. epoch:44, batch:200, cost:0.309721, acc:0.898438
    226. epoch:44, batch:300, cost:0.330030, acc:0.898438
    227. epoch:44, test\_cost:0.348967, test\_acc:0.888017
    228. epoch:45, batch:0, cost:0.382172, acc:0.890625
    229. epoch:45, batch:100, cost:0.292855, acc:0.929688
    230. epoch:45, batch:200, cost:0.445127, acc:0.898438
    231. epoch:45, batch:300, cost:0.365554, acc:0.890625
    232. epoch:45, test\_cost:0.352218, test\_acc:0.883932
    233. epoch:46, batch:0, cost:0.424743, acc:0.898438
    234. epoch:46, batch:100, cost:0.382699, acc:0.859375
    235. epoch:46, batch:200, cost:0.319472, acc:0.914062
    236. epoch:46, batch:300, cost:0.414162, acc:0.859375
    237. epoch:46, test\_cost:0.349987, test\_acc:0.885498
    238. epoch:47, batch:0, cost:0.304131, acc:0.890625
    239. epoch:47, batch:100, cost:0.386861, acc:0.890625
    240. epoch:47, batch:200, cost:0.608894, acc:0.820312
    241. epoch:47, batch:300, cost:0.281832, acc:0.898438
    242. epoch:47, test\_cost:0.349286, test\_acc:0.888276
    243. epoch:48, batch:0, cost:0.406423, acc:0.882812
    244. epoch:48, batch:100, cost:0.398680, acc:0.898438
    245. epoch:48, batch:200, cost:0.291706, acc:0.914062
    246. epoch:48, batch:300, cost:0.358105, acc:0.875000
    247. epoch:48, test\_cost:0.348130, test\_acc:0.888361
    248. epoch:49, batch:0, cost:0.284720, acc:0.914062
    249. epoch:49, batch:100, cost:0.341173, acc:0.898438
    250. epoch:49, batch:200, cost:0.341595, acc:0.859375
    251. epoch:49, batch:300, cost:0.442754, acc:0.820312
    252. epoch:49, test\_cost:0.347218, test\_acc:0.886012
    253. epoch:50, batch:0, cost:0.311721, acc:0.906250
    254. epoch:50, batch:100, cost:0.326822, acc:0.875000
    255. epoch:50, batch:200, cost:0.331799, acc:0.898438
    256. epoch:50, batch:300, cost:0.426647, acc:0.851562
    257. epoch:50, test\_cost:0.347288, test\_acc:0.888535
    258. epoch:51, batch:0, cost:0.389481, acc:0.867188
    259. epoch:51, batch:100, cost:0.289127, acc:0.906250
    260. epoch:51, batch:200, cost:0.328051, acc:0.929688
    261. epoch:51, batch:300, cost:0.426396, acc:0.890625
    262. epoch:51, test\_cost:0.344246, test\_acc:0.889839
    263. epoch:52, batch:0, cost:0.288156, acc:0.906250
    264. epoch:52, batch:100, cost:0.298805, acc:0.906250
    265. epoch:52, batch:200, cost:0.371176, acc:0.921875
    266. epoch:52, batch:300, cost:0.389306, acc:0.875000
    267. epoch:52, test\_cost:0.345692, test\_acc:0.891224
    268. epoch:53, batch:0, cost:0.425932, acc:0.890625
    269. epoch:53, batch:100, cost:0.415528, acc:0.882812
    270. epoch:53, batch:200, cost:0.434767, acc:0.867188
    271. epoch:53, batch:300, cost:0.331441, acc:0.914062
    272. epoch:53, test\_cost:0.340924, test\_acc:0.890101
    273. epoch:54, batch:0, cost:0.260270, acc:0.906250
    274. epoch:54, batch:100, cost:0.305412, acc:0.898438
    275. epoch:54, batch:200, cost:0.330370, acc:0.906250
    276. epoch:54, batch:300, cost:0.334084, acc:0.898438
    277. epoch:54, test\_cost:0.341799, test\_acc:0.892010
    278. epoch:55, batch:0, cost:0.239946, acc:0.937500
    279. epoch:55, batch:100, cost:0.510334, acc:0.898438
    280. epoch:55, batch:200, cost:0.331789, acc:0.898438
    281. epoch:55, batch:300, cost:0.273344, acc:0.898438
    282. epoch:55, test\_cost:0.341348, test\_acc:0.889403
    283. epoch:56, batch:0, cost:0.288282, acc:0.914062
    284. epoch:56, batch:100, cost:0.384843, acc:0.898438
    285. epoch:56, batch:200, cost:0.391903, acc:0.867188
    286. epoch:56, batch:300, cost:0.352458, acc:0.882812
    287. epoch:56, test\_cost:0.338860, test\_acc:0.891054
    288. epoch:57, batch:0, cost:0.434810, acc:0.828125
    289. epoch:57, batch:100, cost:0.257800, acc:0.953125
    290. epoch:57, batch:200, cost:0.283473, acc:0.921875
    291. epoch:57, batch:300, cost:0.337173, acc:0.867188
    292. epoch:57, test\_cost:0.339060, test\_acc:0.891575
    293. epoch:58, batch:0, cost:0.240891, acc:0.898438
    294. epoch:58, batch:100, cost:0.390225, acc:0.875000
    295. epoch:58, batch:200, cost:0.393483, acc:0.843750
    296. epoch:58, batch:300, cost:0.289487, acc:0.890625
    297. epoch:58, test\_cost:0.337302, test\_acc:0.892269
    298. epoch:59, batch:0, cost:0.210337, acc:0.960938
    299. epoch:59, batch:100, cost:0.423231, acc:0.867188
    300. epoch:59, batch:200, cost:0.319490, acc:0.921875
    301. epoch:59, batch:300, cost:0.451494, acc:0.859375
    302. epoch:59, test\_cost:0.336483, test\_acc:0.893137
    303. epoch:60, batch:0, cost:0.231775, acc:0.937500
    304. epoch:60, batch:100, cost:0.295306, acc:0.906250
    305. epoch:60, batch:200, cost:0.378960, acc:0.859375
    306. epoch:60, batch:300, cost:0.350808, acc:0.843750
    307. epoch:60, test\_cost:0.335058, test\_acc:0.894267
    308. epoch:61, batch:0, cost:0.440865, acc:0.867188
    309. epoch:61, batch:100, cost:0.270725, acc:0.882812
    310. epoch:61, batch:200, cost:0.398181, acc:0.851562
    311. epoch:61, batch:300, cost:0.363882, acc:0.921875
    312. epoch:61, test\_cost:0.336761, test\_acc:0.892875
    313. epoch:62, batch:0, cost:0.321757, acc:0.898438
    314. epoch:62, batch:100, cost:0.330311, acc:0.890625
    315. epoch:62, batch:200, cost:0.406124, acc:0.851562
    316. epoch:62, batch:300, cost:0.275819, acc:0.898438
    317. epoch:62, test\_cost:0.342463, test\_acc:0.891565
    318. epoch:63, batch:0, cost:0.321822, acc:0.898438
    319. epoch:63, batch:100, cost:0.322195, acc:0.882812
    320. epoch:63, batch:200, cost:0.432605, acc:0.882812
    321. epoch:63, batch:300, cost:0.377368, acc:0.898438
    322. epoch:63, test\_cost:0.333785, test\_acc:0.895221
    323. epoch:64, batch:0, cost:0.247617, acc:0.882812
    324. epoch:64, batch:100, cost:0.231372, acc:0.921875
    325. epoch:64, batch:200, cost:0.336805, acc:0.867188
    326. epoch:64, batch:300, cost:0.274635, acc:0.898438
    327. epoch:64, test\_cost:0.332033, test\_acc:0.894179
    328. epoch:65, batch:0, cost:0.241076, acc:0.906250
    329. epoch:65, batch:100, cost:0.377462, acc:0.906250
    330. epoch:65, batch:200, cost:0.297226, acc:0.882812
    331. epoch:65, batch:300, cost:0.440397, acc:0.867188
    332. epoch:65, test\_cost:0.330794, test\_acc:0.897045
    333. epoch:66, batch:0, cost:0.266126, acc:0.898438
    334. epoch:66, batch:100, cost:0.390715, acc:0.859375
    335. epoch:66, batch:200, cost:0.292437, acc:0.914062
    336. epoch:66, batch:300, cost:0.395078, acc:0.867188
    337. epoch:66, test\_cost:0.330902, test\_acc:0.895221
    338. epoch:67, batch:0, cost:0.301438, acc:0.929688
    339. epoch:67, batch:100, cost:0.388324, acc:0.898438
    340. epoch:67, batch:200, cost:0.439915, acc:0.890625
    341. epoch:67, batch:300, cost:0.310547, acc:0.867188
    342. epoch:67, test\_cost:0.330386, test\_acc:0.896521
    343. epoch:68, batch:0, cost:0.243119, acc:0.929688
    344. epoch:68, batch:100, cost:0.447522, acc:0.875000
    345. epoch:68, batch:200, cost:0.470691, acc:0.882812
    346. epoch:68, batch:300, cost:0.296465, acc:0.882812
    347. epoch:68, test\_cost:0.326098, test\_acc:0.896266
    348. epoch:69, batch:0, cost:0.260604, acc:0.898438
    349. epoch:69, batch:100, cost:0.417193, acc:0.882812
    350. epoch:69, batch:200, cost:0.483119, acc:0.835938
    351. epoch:69, batch:300, cost:0.405713, acc:0.875000
    352. epoch:69, test\_cost:0.328661, test\_acc:0.896957
    353. epoch:70, batch:0, cost:0.300975, acc:0.882812
    354. epoch:70, batch:100, cost:0.199427, acc:0.945312
    355. epoch:70, batch:200, cost:0.207260, acc:0.937500
    356. epoch:70, batch:300, cost:0.199148, acc:0.914062
    357. epoch:70, test\_cost:0.327545, test\_acc:0.894958
    358. epoch:71, batch:0, cost:0.281955, acc:0.914062
    359. epoch:71, batch:100, cost:0.267508, acc:0.914062
    360. epoch:71, batch:200, cost:0.561389, acc:0.828125
    361. epoch:71, batch:300, cost:0.377676, acc:0.867188
    362. epoch:71, test\_cost:0.325637, test\_acc:0.897740
    363. epoch:72, batch:0, cost:0.348661, acc:0.890625
    364. epoch:72, batch:100, cost:0.346154, acc:0.898438
    365. epoch:72, batch:200, cost:0.447819, acc:0.867188
    366. epoch:72, batch:300, cost:0.342514, acc:0.929688
    367. epoch:72, test\_cost:0.325294, test\_acc:0.897304
    368. epoch:73, batch:0, cost:0.223638, acc:0.929688
    369. epoch:73, batch:100, cost:0.394560, acc:0.859375
    370. epoch:73, batch:200, cost:0.341260, acc:0.890625
    371. epoch:73, batch:300, cost:0.283185, acc:0.898438
    372. epoch:73, test\_cost:0.326340, test\_acc:0.895394
    373. epoch:74, batch:0, cost:0.371942, acc:0.921875
    374. epoch:74, batch:100, cost:0.333636, acc:0.882812
    375. epoch:74, batch:200, cost:0.397030, acc:0.875000
    376. epoch:74, batch:300, cost:0.392802, acc:0.875000
    377. epoch:74, test\_cost:0.322571, test\_acc:0.896089
    378. epoch:75, batch:0, cost:0.275930, acc:0.921875
    379. epoch:75, batch:100, cost:0.263152, acc:0.914062
    380. epoch:75, batch:200, cost:0.296550, acc:0.898438
    381. epoch:75, batch:300, cost:0.402121, acc:0.898438
    382. epoch:75, test\_cost:0.320611, test\_acc:0.897134
    383. epoch:76, batch:0, cost:0.279775, acc:0.921875
    384. epoch:76, batch:100, cost:0.439274, acc:0.843750
    385. epoch:76, batch:200, cost:0.330266, acc:0.898438
    386. epoch:76, batch:300, cost:0.418308, acc:0.851562
    387. epoch:76, test\_cost:0.320242, test\_acc:0.900429
    388. epoch:77, batch:0, cost:0.320668, acc:0.890625
    389. epoch:77, batch:100, cost:0.168939, acc:0.960938
    390. epoch:77, batch:200, cost:0.244379, acc:0.953125
    391. epoch:77, batch:300, cost:0.621534, acc:0.875000
    392. epoch:77, test\_cost:0.319756, test\_acc:0.900865
    393. epoch:78, batch:0, cost:0.284392, acc:0.914062
    394. epoch:78, batch:100, cost:0.309243, acc:0.890625
    395. epoch:78, batch:200, cost:0.273962, acc:0.945312
    396. epoch:78, batch:300, cost:0.311928, acc:0.906250
    397. epoch:78, test\_cost:0.318491, test\_acc:0.901818
    398. epoch:79, batch:0, cost:0.242170, acc:0.898438
    399. epoch:79, batch:100, cost:0.315753, acc:0.875000
    400. epoch:79, batch:200, cost:0.252874, acc:0.937500
    401. epoch:79, batch:300, cost:0.447730, acc:0.812500
    402. epoch:79, test\_cost:0.318828, test\_acc:0.900603
    403. 模型保存成功.

    【推理预测】

    1. model_save_dir = "model/"
    2. def get_data(sentence): # 将传入的句子根据字典中的值进行编码
    3. with open(dict_file_path, "r", encoding="utf-8") as f:
    4. dict_txt = eval(f.readlines()[0])
    5. ret = [] # 编码结果
    6. keys = dict_txt.keys()
    7. for w in sentence: # 取出每个字
    8. if not w in keys: # 字不在字典中
    9. w = "<unk>"
    10. ret.append(int(dict_txt[w]))
    11. return ret
    12. # 执行器
    13. place = fluid.CPUPlace()
    14. exe = fluid.Executor(place)
    15. exe.run(fluid.default_startup_program())
    16. infer_program, feed_names, target_var = \
    17. fluid.io.load_inference_model(model_save_dir, exe)
    18. texts = [] # 存放待预测句子
    19. data1 = get_data("在获得诺贝尔文学奖7年之后,莫言15日晚间在山西汾阳贾家庄如是说")
    20. data2 = get_data("综合'今日美国'、《世界日报》等当地媒体报道,芝加哥河滨警察局表示")
    21. data3 = get_data("中国队2022年冬奥会表现优秀")
    22. data4 = get_data("中国人民银行今日发布通知,降低准备金率,预计释放4000亿流动性")
    23. data5 = get_data("10月20日,第六届世界互联网大会正式开幕")
    24. data6 = get_data("同一户型,为什么高层比低层要贵那么多?")
    25. data7 = get_data("揭秘A股周涨5%资金动向:追捧2类股,抛售600亿香饽饽")
    26. data8 = get_data("宋慧乔陷入感染危机,前夫宋仲基不戴口罩露面,身处国外神态轻松")
    27. data9 = get_data("此盆栽花很好养,花美似牡丹,三季开花,南北都能养,很值得栽培") # 不属于任何一个类别
    28. texts.append(data1)
    29. texts.append(data2)
    30. texts.append(data3)
    31. texts.append(data4)
    32. texts.append(data5)
    33. texts.append(data6)
    34. texts.append(data7)
    35. texts.append(data8)
    36. texts.append(data9)
    37. base_shape = [[len(c) for c in texts]] # 计算每个句子长度
    38. tensor_words = fluid.create_lod_tensor(texts, base_shape, place)
    39. result = exe.run(infer_program,
    40. feed={feed_names[0]: tensor_words},
    41. fetch_list=target_var)
    42. names = ["文化", "娱乐", "体育", "财经", "房产","汽车", "教育", "科技", "国际", "证券"]
    43. for r in result[0]:
    44. idx = np.argmax(r) # 取出最大值的索引
    45. print("预测结果:", names[idx], " 概率:", r[idx])

    输出

    1. 预测结果: 财经 概率: 0.81440145
    2. 预测结果: 娱乐 概率: 1.0
    3. 预测结果: 财经 概率: 1.0
    4. 预测结果: 汽车 概率: 0.9996093
    5. 预测结果: 文化 概率: 0.9404757
    6. 预测结果: 娱乐 概率: 0.8715788
    7. 预测结果: 房产 概率: 0.9625704
    8. 预测结果: 科技 概率: 0.985617
    9. 预测结果: 房产 概率: 1.0

    文章涉及到的数据资源链接如下:news_classify_data.zip(蓝奏云,文件大小 2.7 MB):https://wwt.lanzoum.com/iTF6N1q1rnti

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/362150
推荐阅读
相关标签
  

闽ICP备14008679号