Text classification assigns text to categories based on its content; in a news system, for example, every article is filed under a category. Typical applications:
Content classification (e.g., news categorization)
Mail filtering (e.g., spam filtering)
User classification (e.g., spending tier and preferences in an online store)
Sentiment classification of comments, articles, and conversations (positive, negative, neutral)
Task: build a text classification model, then train and evaluate it so that it assigns Chinese news abstracts to the correct category.
Dataset: 56,821 Chinese news abstracts crawled from the web, spanning 10 categories: International, Culture, Entertainment, Sports, Finance, Automotive, Education, Technology, Real Estate, and Securities.
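Each raw record separates its fields with "_!_"; the parsing code below takes the field at index 1 as the numeric class label and the field at index 3 as the title. A minimal sketch with a hypothetical record (the field values are invented; only the separator and the field positions come from the preprocessing code):
- sample = "101_!_3_!_news_sports_!_中国队夺得世锦赛冠军"  # hypothetical raw record
- fields = sample.split("_!_")  # split on the dataset's field separator
- label = fields[1]  # numeric class label, e.g. "3"
- title = fields[3]  # news title used as the model input
- print(label, title)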
Model selection: TextCNN, i.e. an embedding layer feeding two parallel convolution/pooling branches whose outputs are combined by a softmax fully connected layer (see the network definition below).
Steps: preprocess the data (build a character dictionary, encode the samples, split them into training and test sets); define and train the model, evaluating it after each epoch; save the trained model and run inference on new headlines.
Code
[Data Preprocessing]
- ########################### Data preprocessing #########################
- import os
- from multiprocessing import cpu_count
- import numpy as np
- import paddle
- import paddle.fluid as fluid
-
- # Common variables
- data_root = "data/"  # dataset directory
- data_file = "news_classify_data.txt"  # raw dataset
- train_file = "train.txt"  # training set file
- test_file = "test.txt"  # test set file
- dict_file = "dict_txt.txt"  # dictionary file (character-to-code mapping)
-
- data_file_path = data_root + data_file  # full path of raw dataset
- train_file_path = data_root + train_file  # full path of training set
- test_file_path = data_root + test_file  # full path of test set
- dict_file_path = data_root + dict_file  # full path of dictionary file
-
- # Collect every character in the samples, assign each a code,
- # and write the result to the dictionary file
- def create_dict():
-     dict_set = set()  # set, used for de-duplication
-     with open(data_file_path, "r", encoding="utf-8") as f:
-         for line in f.readlines():  # iterate over lines
-             line = line.replace("\n", "")  # strip the newline
-             tmp_list = line.split("_!_")  # split on the field separator
-             title = tmp_list[-1]  # the last field is the title
-             for word in title:  # take each character
-                 dict_set.add(word)
-
-     # Iterate over the set and number each character
-     dict_txt = {}  # the dictionary
-     i = 1  # counter used as the code
-     for word in dict_set:
-         dict_txt[word] = i  # add the character-code pair
-         i += 1
-
-     dict_txt["<unk>"] = i  # unknown character (never seen in the samples)
-
-     # Write the dictionary to file
-     with open(dict_file_path, "w", encoding="utf-8") as f:
-         f.write(str(dict_txt))
-
-     print("Finished building the dictionary.")
-
- # Encode each character of a sentence and return it together with the label
- def line_encoding(title, dict_txt, label):
-     new_line = ""  # encoded result
-     for word in title:
-         if word in dict_txt:  # character is in the dictionary
-             code = str(dict_txt[word])  # take its code
-         else:  # not in the dictionary
-             code = str(dict_txt["<unk>"])  # use the unknown-character code
-         new_line = new_line + code + ","  # append to the string
-     new_line = new_line[:-1]  # drop the trailing comma
-     new_line = new_line + "\t" + label + "\n"  # append the label
-     return new_line
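- # Example (hypothetical dictionary): with dict_txt = {"中": 5, "国": 9, "<unk>": 10},
- # line_encoding("中国", dict_txt, "3") returns "5,9\t3\n".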
-
- # Read the raw samples, encode the title field, and split into train/test sets
- def create_train_test_file():
-     # Truncate the train/test files
-     with open(train_file_path, "w") as f:
-         pass
-     with open(test_file_path, "w") as f:
-         pass
-
-     # Load the dictionary file
-     with open(dict_file_path, "r", encoding="utf-8") as f_dict:
-         dict_txt = eval(f_dict.readlines()[0])  # the first line holds the dict literal
-
-     # Read the raw samples
-     with open(data_file_path, "r", encoding="utf-8") as f_data:
-         lines = f_data.readlines()
-
-     i = 0
-     for line in lines:
-         tmp_list = line.replace("\n", "").split("_!_")  # split fields
-         title = tmp_list[3]  # title
-         label = tmp_list[1]  # category label
-         new_line = line_encoding(title, dict_txt, label)  # encode the title
-
-         if i % 10 == 0:  # every 10th sample goes to the test set
-             with open(test_file_path, "a", encoding="utf-8") as f:
-                 f.write(new_line)
-         else:  # the rest go to the training set
-             with open(train_file_path, "a", encoding="utf-8") as f:
-                 f.write(new_line)
-         i += 1
-     print("Finished generating the train/test sets.")
-
- create_dict()  # build the dictionary from the samples
- create_train_test_file()
Output:
- Finished building the dictionary.
- Finished generating the train/test sets.
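As a quick sanity check (a sketch, assuming the paths defined above and that the files were just written), the first line of train.txt should be a comma-separated code list, a tab, and a label:
- with open(train_file_path, "r", encoding="utf-8") as f:
-     first = f.readline().strip()
- codes, label = first.split("\t")
- print("first 5 codes:", codes.split(",")[:5], "label:", label)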
[Model Definition and Training]
- paddle.enable_static()
- # Read the dictionary file and return the vocabulary size
- def get_dict_len(dict_path):
-     with open(dict_path, "r", encoding="utf-8") as f:
-         dict_txt = eval(f.readlines()[0])
-     # +1 because codes start at 1 and "<unk>" holds the largest code,
-     # so the embedding table needs one extra row to cover every index
-     return len(dict_txt.keys()) + 1
-
- def data_mapper(sample):
-     data, label = sample  # unpack the sample
-     val = [int(w) for w in data.split(",")]  # codes were read from file as strings; convert to ints
-     return val, int(label)
-
- def train_reader(train_file_path):  # training set reader
-     def reader():
-         with open(train_file_path, "r") as f:
-             lines = f.readlines()
-             np.random.shuffle(lines)  # shuffle the samples
-             for line in lines:
-                 data, label = line.split("\t")  # split codes and label
-                 yield data, label
-     return paddle.reader.xmap_readers(data_mapper, reader, cpu_count(), 1024)
-
- def test_reader(test_file_path):  # test set reader
-     def reader():
-         with open(test_file_path, "r") as f:
-             lines = f.readlines()
-
-             for line in lines:
-                 data, label = line.split("\t")  # split codes and label
-                 yield data, label
-     return paddle.reader.xmap_readers(data_mapper, reader, cpu_count(), 1024)
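- # Note: paddle.reader.xmap_readers wraps the raw reader so that data_mapper
- # runs in cpu_count() worker threads with a 1024-sample buffer, decoding
- # samples in parallel while batches are consumed.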
-
- # Network definition
- def Text_CNN(data, dict_dim, class_dim=10, emb_dim=128,
-              hid_dim=128, hid_dim2=128):
-     """
-     TextCNN model
-     :param data: input
-     :param dict_dim: vocabulary size (total number of characters)
-     :param class_dim: number of classes
-     :param emb_dim: embedding size
-     :param hid_dim: number of filters in the first convolution layer
-     :param hid_dim2: number of filters in the second convolution layer
-     :return: model predictions
-     """
-     # Embedding layer
-     emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
-     # Two parallel convolution/pooling branches
-     conv1 = fluid.nets.sequence_conv_pool(input=emb,  # input (embedding output)
-                                           num_filters=hid_dim,  # number of filters
-                                           filter_size=3,  # filter size
-                                           act="tanh",  # activation
-                                           pool_type="sqrt")  # pooling type
-     conv2 = fluid.nets.sequence_conv_pool(input=emb,  # input (embedding output)
-                                           num_filters=hid_dim2,  # number of filters
-                                           filter_size=4,  # filter size
-                                           act="tanh",  # activation
-                                           pool_type="sqrt")  # pooling type
-     # Fully connected output layer
-     output = fluid.layers.fc(input=[conv1, conv2],  # inputs
-                              size=class_dim,  # number of outputs
-                              act="softmax")  # activation
-     return output
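- # Note: the two branches look at 3-character and 4-character windows of the
- # embedded sequence; fc() gives each pooled branch its own weight matrix and
- # sums the two projections before the softmax.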
-
- # Placeholder tensors
- words = fluid.layers.data(name="words",
-                           shape=[1],
-                           dtype="int64",
-                           lod_level=1)  # an LoD tensor represents variable-length data
- label = fluid.layers.data(name="label",
-                           shape=[1],
-                           dtype="int64")
- dict_dim = get_dict_len(dict_file_path)  # vocabulary size
- # Build the model
- model = Text_CNN(words, dict_dim)
- # Loss function
- cost = fluid.layers.cross_entropy(input=model, label=label)
- avg_cost = fluid.layers.mean(cost)
- # Optimizer
- optimizer = fluid.optimizer.Adam(learning_rate=0.0001)
- optimizer.minimize(avg_cost)
- # Accuracy
- accuracy = fluid.layers.accuracy(input=model, label=label)
-
- # Executor
- place = fluid.CUDAPlace(0)
- exe = fluid.Executor(place)
- exe.run(fluid.default_startup_program())
-
- # Readers
- ## training set reader
- tr_reader = train_reader(train_file_path)
- batch_train_reader = paddle.batch(tr_reader, batch_size=128)
- ## test set reader
- ts_reader = test_reader(test_file_path)
- batch_test_reader = paddle.batch(ts_reader, batch_size=128)
-
- # Feeder
- feeder = fluid.DataFeeder(place=place, feed_list=[words, label])
-
- # Training
- for epoch in range(80):  # outer loop over epochs
-     for batch_id, data in enumerate(batch_train_reader()):  # inner loop over batches
-         train_cost, train_acc = exe.run(fluid.default_main_program(),  # program
-                                         feed=feeder.feed(data),  # fed data
-                                         fetch_list=[avg_cost, accuracy])  # fetched values
-         if batch_id % 100 == 0:
-             print("epoch:%d, batch:%d, cost:%f, acc:%f" %
-                   (epoch, batch_id, train_cost[0], train_acc[0]))
-
-     # Evaluate on the test set after each epoch
-     test_costs_list = []  # test losses
-     test_accs_list = []  # test accuracies
-
-     for batch_id, data in enumerate(batch_test_reader()):
-         test_cost, test_acc = exe.run(fluid.default_main_program(),
-                                       feed=feeder.feed(data),
-                                       fetch_list=[avg_cost, accuracy])
-         test_costs_list.append(test_cost[0])
-         test_accs_list.append(test_acc[0])
-     # Mean loss/accuracy over all test batches
-     avg_test_cost = sum(test_costs_list) / len(test_costs_list)
-     avg_test_acc = sum(test_accs_list) / len(test_accs_list)
-     print("epoch:%d, test_cost:%f, test_acc:%f" %
-           (epoch, avg_test_cost, avg_test_acc))
-
- # Save the model after training
- model_save_dir = "model/"
- if not os.path.exists(model_save_dir):
-     os.makedirs(model_save_dir)
- fluid.io.save_inference_model(model_save_dir,  # save path
-                               feeded_var_names=[words.name],  # names of variables to feed at inference time
-                               target_vars=[model],  # prediction outputs
-                               executor=exe)  # executor
- print("Model saved successfully.")
Output:
- W0301 17:44:27.392561 134 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 12.0, Runtime API Version: 11.2
- W0301 17:44:27.396858 134 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
-
- epoch:0, batch:0, cost:2.311182, acc:0.062500
- epoch:0, batch:100, cost:2.063974, acc:0.398438
- epoch:0, batch:200, cost:1.564050, acc:0.531250
- epoch:0, batch:300, cost:1.113311, acc:0.726562
- epoch:0, test_cost:0.984438, test_acc:0.699619
- epoch:1, batch:0, cost:0.904973, acc:0.703125
- epoch:1, batch:100, cost:1.056342, acc:0.648438
- epoch:1, batch:200, cost:0.757222, acc:0.773438
- epoch:1, batch:300, cost:0.768483, acc:0.757812
- epoch:1, test_cost:0.743878, test_acc:0.760737
- epoch:2, batch:0, cost:0.799179, acc:0.773438
- epoch:2, batch:100, cost:0.684252, acc:0.796875
- epoch:2, batch:200, cost:0.700367, acc:0.773438
- epoch:2, batch:300, cost:0.658518, acc:0.820312
- epoch:2, test_cost:0.662426, test_acc:0.787224
- epoch:3, batch:0, cost:0.732371, acc:0.757812
- epoch:3, batch:100, cost:0.621147, acc:0.804688
- epoch:3, batch:200, cost:0.592039, acc:0.820312
- epoch:3, batch:300, cost:0.566666, acc:0.835938
- epoch:3, test_cost:0.613746, test_acc:0.799462
- epoch:4, batch:0, cost:0.477513, acc:0.835938
- epoch:4, batch:100, cost:0.510523, acc:0.820312
- epoch:4, batch:200, cost:0.635614, acc:0.765625
- epoch:4, batch:300, cost:0.487310, acc:0.820312
- epoch:4, test_cost:0.579772, test_acc:0.810488
- epoch:5, batch:0, cost:0.564993, acc:0.796875
- epoch:5, batch:100, cost:0.406841, acc:0.882812
- epoch:5, batch:200, cost:0.444392, acc:0.867188
- epoch:5, batch:300, cost:0.729928, acc:0.765625
- epoch:5, test_cost:0.549828, test_acc:0.820210
- epoch:6, batch:0, cost:0.676264, acc:0.757812
- epoch:6, batch:100, cost:0.596377, acc:0.859375
- epoch:6, batch:200, cost:0.532443, acc:0.773438
- epoch:6, batch:300, cost:0.534213, acc:0.820312
- epoch:6, test_cost:0.529354, test_acc:0.824554
- epoch:7, batch:0, cost:0.426611, acc:0.867188
- epoch:7, batch:100, cost:0.474698, acc:0.851562
- epoch:7, batch:200, cost:0.529180, acc:0.851562
- epoch:7, batch:300, cost:0.512545, acc:0.820312
- epoch:7, test_cost:0.510936, test_acc:0.830113
- epoch:8, batch:0, cost:0.559594, acc:0.859375
- epoch:8, batch:100, cost:0.670438, acc:0.812500
- epoch:8, batch:200, cost:0.519166, acc:0.796875
- epoch:8, batch:300, cost:0.487624, acc:0.835938
- epoch:8, test_cost:0.495487, test_acc:0.831584
- epoch:9, batch:0, cost:0.523822, acc:0.875000
- epoch:9, batch:100, cost:0.604066, acc:0.812500
- epoch:9, batch:200, cost:0.420712, acc:0.835938
- epoch:9, batch:300, cost:0.490420, acc:0.867188
- epoch:9, test_cost:0.480619, test_acc:0.838099
- epoch:10, batch:0, cost:0.337900, acc:0.875000
- epoch:10, batch:100, cost:0.513283, acc:0.820312
- epoch:10, batch:200, cost:0.471449, acc:0.835938
- epoch:10, batch:300, cost:0.387642, acc:0.859375
- epoch:10, test_cost:0.470769, test_acc:0.842698
- epoch:11, batch:0, cost:0.448336, acc:0.835938
- epoch:11, batch:100, cost:0.440521, acc:0.843750
- epoch:11, batch:200, cost:0.436573, acc:0.875000
- epoch:11, batch:300, cost:0.694043, acc:0.781250
- epoch:11, test_cost:0.460047, test_acc:0.846606
- epoch:12, batch:0, cost:0.481746, acc:0.882812
- epoch:12, batch:100, cost:0.552047, acc:0.804688
- epoch:12, batch:200, cost:0.386960, acc:0.882812
- epoch:12, batch:300, cost:0.341771, acc:0.882812
- epoch:12, test_cost:0.450299, test_acc:0.850769
- epoch:13, batch:0, cost:0.547187, acc:0.859375
- epoch:13, batch:100, cost:0.356574, acc:0.875000
- epoch:13, batch:200, cost:0.379301, acc:0.867188
- epoch:13, batch:300, cost:0.374310, acc:0.898438
- epoch:13, test_cost:0.440953, test_acc:0.856413
- epoch:14, batch:0, cost:0.689324, acc:0.843750
- epoch:14, batch:100, cost:0.493557, acc:0.851562
- epoch:14, batch:200, cost:0.332372, acc:0.898438
- epoch:14, batch:300, cost:0.462879, acc:0.867188
- epoch:14, test_cost:0.434527, test_acc:0.858323
- epoch:15, batch:0, cost:0.539430, acc:0.859375
- epoch:15, batch:100, cost:0.610829, acc:0.828125
- epoch:15, batch:200, cost:0.567205, acc:0.773438
- epoch:15, batch:300, cost:0.448473, acc:0.851562
- epoch:15, test_cost:0.429657, test_acc:0.858929
- epoch:16, batch:0, cost:0.275449, acc:0.914062
- epoch:16, batch:100, cost:0.322028, acc:0.890625
- epoch:16, batch:200, cost:0.433238, acc:0.843750
- epoch:16, batch:300, cost:0.435924, acc:0.875000
- epoch:16, test_cost:0.420777, test_acc:0.864314
- epoch:17, batch:0, cost:0.463425, acc:0.859375
- epoch:17, batch:100, cost:0.498210, acc:0.890625
- epoch:17, batch:200, cost:0.347210, acc:0.898438
- epoch:17, batch:300, cost:0.375353, acc:0.867188
- epoch:17, test_cost:0.414920, test_acc:0.866483
- epoch:18, batch:0, cost:0.371144, acc:0.906250
- epoch:18, batch:100, cost:0.511294, acc:0.828125
- epoch:18, batch:200, cost:0.431728, acc:0.828125
- epoch:18, batch:300, cost:0.505222, acc:0.843750
- epoch:18, test_cost:0.412018, test_acc:0.866136
- epoch:19, batch:0, cost:0.417319, acc:0.859375
- epoch:19, batch:100, cost:0.405875, acc:0.867188
- epoch:19, batch:200, cost:0.466319, acc:0.843750
- epoch:19, batch:300, cost:0.524598, acc:0.820312
- epoch:19, test_cost:0.408254, test_acc:0.870387
- epoch:20, batch:0, cost:0.278774, acc:0.921875
- epoch:20, batch:100, cost:0.375402, acc:0.875000
- epoch:20, batch:200, cost:0.512493, acc:0.851562
- epoch:20, batch:300, cost:0.352869, acc:0.867188
- epoch:20, test_cost:0.402862, test_acc:0.870646
- epoch:21, batch:0, cost:0.328388, acc:0.890625
- epoch:21, batch:100, cost:0.474930, acc:0.843750
- epoch:21, batch:200, cost:0.279459, acc:0.898438
- epoch:21, batch:300, cost:0.480916, acc:0.843750
- epoch:21, test_cost:0.398193, test_acc:0.870476
- epoch:22, batch:0, cost:0.360476, acc:0.914062
- epoch:22, batch:100, cost:0.399123, acc:0.867188
- epoch:22, batch:200, cost:0.330940, acc:0.898438
- epoch:22, batch:300, cost:0.449070, acc:0.851562
- epoch:22, test_cost:0.396272, test_acc:0.872644
- epoch:23, batch:0, cost:0.311765, acc:0.882812
- epoch:23, batch:100, cost:0.430598, acc:0.859375
- epoch:23, batch:200, cost:0.371466, acc:0.867188
- epoch:23, batch:300, cost:0.497460, acc:0.859375
- epoch:23, test_cost:0.391935, test_acc:0.874990
- epoch:24, batch:0, cost:0.278461, acc:0.921875
- epoch:24, batch:100, cost:0.384332, acc:0.867188
- epoch:24, batch:200, cost:0.687089, acc:0.804688
- epoch:24, batch:300, cost:0.465835, acc:0.835938
- epoch:24, test_cost:0.386384, test_acc:0.874905
- epoch:25, batch:0, cost:0.359800, acc:0.914062
- epoch:25, batch:100, cost:0.370942, acc:0.906250
- epoch:25, batch:200, cost:0.343612, acc:0.906250
- epoch:25, batch:300, cost:0.373149, acc:0.859375
- epoch:25, test_cost:0.385754, test_acc:0.875249
- epoch:26, batch:0, cost:0.359912, acc:0.859375
- epoch:26, batch:100, cost:0.299233, acc:0.906250
- epoch:26, batch:200, cost:0.321898, acc:0.882812
- epoch:26, batch:300, cost:0.506139, acc:0.820312
- epoch:26, test_cost:0.382092, test_acc:0.877597
- epoch:27, batch:0, cost:0.438806, acc:0.882812
- epoch:27, batch:100, cost:0.351698, acc:0.867188
- epoch:27, batch:200, cost:0.413263, acc:0.875000
- epoch:27, batch:300, cost:0.327677, acc:0.875000
- epoch:27, test_cost:0.379122, test_acc:0.880460
- epoch:28, batch:0, cost:0.329184, acc:0.921875
- epoch:28, batch:100, cost:0.489258, acc:0.882812
- epoch:28, batch:200, cost:0.375317, acc:0.890625
- epoch:28, batch:300, cost:0.355702, acc:0.859375
- epoch:28, test_cost:0.377964, test_acc:0.881066
- epoch:29, batch:0, cost:0.360147, acc:0.882812
- epoch:29, batch:100, cost:0.361545, acc:0.906250
- epoch:29, batch:200, cost:0.535644, acc:0.812500
- epoch:29, batch:300, cost:0.463827, acc:0.789062
- epoch:29, test_cost:0.374992, test_acc:0.879156
- epoch:30, batch:0, cost:0.386321, acc:0.843750
- epoch:30, batch:100, cost:0.450116, acc:0.851562
- epoch:30, batch:200, cost:0.380319, acc:0.867188
- epoch:30, batch:300, cost:0.357393, acc:0.914062
- epoch:30, test_cost:0.372232, test_acc:0.880198
- epoch:31, batch:0, cost:0.338851, acc:0.882812
- epoch:31, batch:100, cost:0.418707, acc:0.890625
- epoch:31, batch:200, cost:0.349568, acc:0.875000
- epoch:31, batch:300, cost:0.414638, acc:0.882812
- epoch:31, test_cost:0.373127, test_acc:0.879245
- epoch:32, batch:0, cost:0.278832, acc:0.906250
- epoch:32, batch:100, cost:0.538143, acc:0.851562
- epoch:32, batch:200, cost:0.418359, acc:0.890625
- epoch:32, batch:300, cost:0.510367, acc:0.875000
- epoch:32, test_cost:0.370239, test_acc:0.880896
- epoch:33, batch:0, cost:0.410598, acc:0.835938
- epoch:33, batch:100, cost:0.295002, acc:0.906250
- epoch:33, batch:200, cost:0.430560, acc:0.828125
- epoch:33, batch:300, cost:0.417476, acc:0.859375
- epoch:33, test_cost:0.367410, test_acc:0.881155
- epoch:34, batch:0, cost:0.337740, acc:0.937500
- epoch:34, batch:100, cost:0.304080, acc:0.906250
- epoch:34, batch:200, cost:0.359049, acc:0.890625
- epoch:34, batch:300, cost:0.373999, acc:0.890625
- epoch:34, test_cost:0.367002, test_acc:0.880113
- epoch:35, batch:0, cost:0.411581, acc:0.898438
- epoch:35, batch:100, cost:0.400797, acc:0.851562
- epoch:35, batch:200, cost:0.482271, acc:0.828125
- epoch:35, batch:300, cost:0.340450, acc:0.890625
- epoch:35, test_cost:0.363663, test_acc:0.883068
- epoch:36, batch:0, cost:0.338912, acc:0.875000
- epoch:36, batch:100, cost:0.416916, acc:0.867188
- epoch:36, batch:200, cost:0.313621, acc:0.882812
- epoch:36, batch:300, cost:0.677497, acc:0.796875
- epoch:36, test_cost:0.361819, test_acc:0.882983
- epoch:37, batch:0, cost:0.329249, acc:0.867188
- epoch:37, batch:100, cost:0.375915, acc:0.890625
- epoch:37, batch:200, cost:0.290267, acc:0.906250
- epoch:37, batch:300, cost:0.388264, acc:0.859375
- epoch:37, test_cost:0.363713, test_acc:0.880025
- epoch:38, batch:0, cost:0.452093, acc:0.875000
- epoch:38, batch:100, cost:0.237014, acc:0.898438
- epoch:38, batch:200, cost:0.334976, acc:0.898438
- epoch:38, batch:300, cost:0.386618, acc:0.875000
- epoch:38, test_cost:0.357681, test_acc:0.884889
- epoch:39, batch:0, cost:0.397014, acc:0.867188
- epoch:39, batch:100, cost:0.387132, acc:0.882812
- epoch:39, batch:200, cost:0.262646, acc:0.921875
- epoch:39, batch:300, cost:0.295718, acc:0.906250
- epoch:39, test_cost:0.358814, test_acc:0.884542
- epoch:40, batch:0, cost:0.336061, acc:0.875000
- epoch:40, batch:100, cost:0.393282, acc:0.867188
- epoch:40, batch:200, cost:0.453071, acc:0.867188
- epoch:40, batch:300, cost:0.276213, acc:0.921875
- epoch:40, test_cost:0.355846, test_acc:0.886278
- epoch:41, batch:0, cost:0.362588, acc:0.867188
- epoch:41, batch:100, cost:0.293396, acc:0.914062
- epoch:41, batch:200, cost:0.351766, acc:0.890625
- epoch:41, batch:300, cost:0.437711, acc:0.820312
- epoch:41, test_cost:0.356017, test_acc:0.886799
- epoch:42, batch:0, cost:0.431722, acc:0.843750
- epoch:42, batch:100, cost:0.296809, acc:0.914062
- epoch:42, batch:200, cost:0.300333, acc:0.898438
- epoch:42, batch:300, cost:0.392034, acc:0.859375
- epoch:42, test_cost:0.354504, test_acc:0.885580
- epoch:43, batch:0, cost:0.237395, acc:0.945312
- epoch:43, batch:100, cost:0.274653, acc:0.914062
- epoch:43, batch:200, cost:0.320165, acc:0.898438
- epoch:43, batch:300, cost:0.233366, acc:0.937500
- epoch:43, test_cost:0.352862, test_acc:0.885410
- epoch:44, batch:0, cost:0.309431, acc:0.953125
- epoch:44, batch:100, cost:0.371803, acc:0.843750
- epoch:44, batch:200, cost:0.309721, acc:0.898438
- epoch:44, batch:300, cost:0.330030, acc:0.898438
- epoch:44, test_cost:0.348967, test_acc:0.888017
- epoch:45, batch:0, cost:0.382172, acc:0.890625
- epoch:45, batch:100, cost:0.292855, acc:0.929688
- epoch:45, batch:200, cost:0.445127, acc:0.898438
- epoch:45, batch:300, cost:0.365554, acc:0.890625
- epoch:45, test_cost:0.352218, test_acc:0.883932
- epoch:46, batch:0, cost:0.424743, acc:0.898438
- epoch:46, batch:100, cost:0.382699, acc:0.859375
- epoch:46, batch:200, cost:0.319472, acc:0.914062
- epoch:46, batch:300, cost:0.414162, acc:0.859375
- epoch:46, test_cost:0.349987, test_acc:0.885498
- epoch:47, batch:0, cost:0.304131, acc:0.890625
- epoch:47, batch:100, cost:0.386861, acc:0.890625
- epoch:47, batch:200, cost:0.608894, acc:0.820312
- epoch:47, batch:300, cost:0.281832, acc:0.898438
- epoch:47, test_cost:0.349286, test_acc:0.888276
- epoch:48, batch:0, cost:0.406423, acc:0.882812
- epoch:48, batch:100, cost:0.398680, acc:0.898438
- epoch:48, batch:200, cost:0.291706, acc:0.914062
- epoch:48, batch:300, cost:0.358105, acc:0.875000
- epoch:48, test_cost:0.348130, test_acc:0.888361
- epoch:49, batch:0, cost:0.284720, acc:0.914062
- epoch:49, batch:100, cost:0.341173, acc:0.898438
- epoch:49, batch:200, cost:0.341595, acc:0.859375
- epoch:49, batch:300, cost:0.442754, acc:0.820312
- epoch:49, test_cost:0.347218, test_acc:0.886012
- epoch:50, batch:0, cost:0.311721, acc:0.906250
- epoch:50, batch:100, cost:0.326822, acc:0.875000
- epoch:50, batch:200, cost:0.331799, acc:0.898438
- epoch:50, batch:300, cost:0.426647, acc:0.851562
- epoch:50, test_cost:0.347288, test_acc:0.888535
- epoch:51, batch:0, cost:0.389481, acc:0.867188
- epoch:51, batch:100, cost:0.289127, acc:0.906250
- epoch:51, batch:200, cost:0.328051, acc:0.929688
- epoch:51, batch:300, cost:0.426396, acc:0.890625
- epoch:51, test_cost:0.344246, test_acc:0.889839
- epoch:52, batch:0, cost:0.288156, acc:0.906250
- epoch:52, batch:100, cost:0.298805, acc:0.906250
- epoch:52, batch:200, cost:0.371176, acc:0.921875
- epoch:52, batch:300, cost:0.389306, acc:0.875000
- epoch:52, test_cost:0.345692, test_acc:0.891224
- epoch:53, batch:0, cost:0.425932, acc:0.890625
- epoch:53, batch:100, cost:0.415528, acc:0.882812
- epoch:53, batch:200, cost:0.434767, acc:0.867188
- epoch:53, batch:300, cost:0.331441, acc:0.914062
- epoch:53, test_cost:0.340924, test_acc:0.890101
- epoch:54, batch:0, cost:0.260270, acc:0.906250
- epoch:54, batch:100, cost:0.305412, acc:0.898438
- epoch:54, batch:200, cost:0.330370, acc:0.906250
- epoch:54, batch:300, cost:0.334084, acc:0.898438
- epoch:54, test_cost:0.341799, test_acc:0.892010
- epoch:55, batch:0, cost:0.239946, acc:0.937500
- epoch:55, batch:100, cost:0.510334, acc:0.898438
- epoch:55, batch:200, cost:0.331789, acc:0.898438
- epoch:55, batch:300, cost:0.273344, acc:0.898438
- epoch:55, test_cost:0.341348, test_acc:0.889403
- epoch:56, batch:0, cost:0.288282, acc:0.914062
- epoch:56, batch:100, cost:0.384843, acc:0.898438
- epoch:56, batch:200, cost:0.391903, acc:0.867188
- epoch:56, batch:300, cost:0.352458, acc:0.882812
- epoch:56, test_cost:0.338860, test_acc:0.891054
- epoch:57, batch:0, cost:0.434810, acc:0.828125
- epoch:57, batch:100, cost:0.257800, acc:0.953125
- epoch:57, batch:200, cost:0.283473, acc:0.921875
- epoch:57, batch:300, cost:0.337173, acc:0.867188
- epoch:57, test_cost:0.339060, test_acc:0.891575
- epoch:58, batch:0, cost:0.240891, acc:0.898438
- epoch:58, batch:100, cost:0.390225, acc:0.875000
- epoch:58, batch:200, cost:0.393483, acc:0.843750
- epoch:58, batch:300, cost:0.289487, acc:0.890625
- epoch:58, test_cost:0.337302, test_acc:0.892269
- epoch:59, batch:0, cost:0.210337, acc:0.960938
- epoch:59, batch:100, cost:0.423231, acc:0.867188
- epoch:59, batch:200, cost:0.319490, acc:0.921875
- epoch:59, batch:300, cost:0.451494, acc:0.859375
- epoch:59, test_cost:0.336483, test_acc:0.893137
- epoch:60, batch:0, cost:0.231775, acc:0.937500
- epoch:60, batch:100, cost:0.295306, acc:0.906250
- epoch:60, batch:200, cost:0.378960, acc:0.859375
- epoch:60, batch:300, cost:0.350808, acc:0.843750
- epoch:60, test_cost:0.335058, test_acc:0.894267
- epoch:61, batch:0, cost:0.440865, acc:0.867188
- epoch:61, batch:100, cost:0.270725, acc:0.882812
- epoch:61, batch:200, cost:0.398181, acc:0.851562
- epoch:61, batch:300, cost:0.363882, acc:0.921875
- epoch:61, test_cost:0.336761, test_acc:0.892875
- epoch:62, batch:0, cost:0.321757, acc:0.898438
- epoch:62, batch:100, cost:0.330311, acc:0.890625
- epoch:62, batch:200, cost:0.406124, acc:0.851562
- epoch:62, batch:300, cost:0.275819, acc:0.898438
- epoch:62, test_cost:0.342463, test_acc:0.891565
- epoch:63, batch:0, cost:0.321822, acc:0.898438
- epoch:63, batch:100, cost:0.322195, acc:0.882812
- epoch:63, batch:200, cost:0.432605, acc:0.882812
- epoch:63, batch:300, cost:0.377368, acc:0.898438
- epoch:63, test_cost:0.333785, test_acc:0.895221
- epoch:64, batch:0, cost:0.247617, acc:0.882812
- epoch:64, batch:100, cost:0.231372, acc:0.921875
- epoch:64, batch:200, cost:0.336805, acc:0.867188
- epoch:64, batch:300, cost:0.274635, acc:0.898438
- epoch:64, test_cost:0.332033, test_acc:0.894179
- epoch:65, batch:0, cost:0.241076, acc:0.906250
- epoch:65, batch:100, cost:0.377462, acc:0.906250
- epoch:65, batch:200, cost:0.297226, acc:0.882812
- epoch:65, batch:300, cost:0.440397, acc:0.867188
- epoch:65, test_cost:0.330794, test_acc:0.897045
- epoch:66, batch:0, cost:0.266126, acc:0.898438
- epoch:66, batch:100, cost:0.390715, acc:0.859375
- epoch:66, batch:200, cost:0.292437, acc:0.914062
- epoch:66, batch:300, cost:0.395078, acc:0.867188
- epoch:66, test_cost:0.330902, test_acc:0.895221
- epoch:67, batch:0, cost:0.301438, acc:0.929688
- epoch:67, batch:100, cost:0.388324, acc:0.898438
- epoch:67, batch:200, cost:0.439915, acc:0.890625
- epoch:67, batch:300, cost:0.310547, acc:0.867188
- epoch:67, test_cost:0.330386, test_acc:0.896521
- epoch:68, batch:0, cost:0.243119, acc:0.929688
- epoch:68, batch:100, cost:0.447522, acc:0.875000
- epoch:68, batch:200, cost:0.470691, acc:0.882812
- epoch:68, batch:300, cost:0.296465, acc:0.882812
- epoch:68, test_cost:0.326098, test_acc:0.896266
- epoch:69, batch:0, cost:0.260604, acc:0.898438
- epoch:69, batch:100, cost:0.417193, acc:0.882812
- epoch:69, batch:200, cost:0.483119, acc:0.835938
- epoch:69, batch:300, cost:0.405713, acc:0.875000
- epoch:69, test_cost:0.328661, test_acc:0.896957
- epoch:70, batch:0, cost:0.300975, acc:0.882812
- epoch:70, batch:100, cost:0.199427, acc:0.945312
- epoch:70, batch:200, cost:0.207260, acc:0.937500
- epoch:70, batch:300, cost:0.199148, acc:0.914062
- epoch:70, test_cost:0.327545, test_acc:0.894958
- epoch:71, batch:0, cost:0.281955, acc:0.914062
- epoch:71, batch:100, cost:0.267508, acc:0.914062
- epoch:71, batch:200, cost:0.561389, acc:0.828125
- epoch:71, batch:300, cost:0.377676, acc:0.867188
- epoch:71, test_cost:0.325637, test_acc:0.897740
- epoch:72, batch:0, cost:0.348661, acc:0.890625
- epoch:72, batch:100, cost:0.346154, acc:0.898438
- epoch:72, batch:200, cost:0.447819, acc:0.867188
- epoch:72, batch:300, cost:0.342514, acc:0.929688
- epoch:72, test_cost:0.325294, test_acc:0.897304
- epoch:73, batch:0, cost:0.223638, acc:0.929688
- epoch:73, batch:100, cost:0.394560, acc:0.859375
- epoch:73, batch:200, cost:0.341260, acc:0.890625
- epoch:73, batch:300, cost:0.283185, acc:0.898438
- epoch:73, test_cost:0.326340, test_acc:0.895394
- epoch:74, batch:0, cost:0.371942, acc:0.921875
- epoch:74, batch:100, cost:0.333636, acc:0.882812
- epoch:74, batch:200, cost:0.397030, acc:0.875000
- epoch:74, batch:300, cost:0.392802, acc:0.875000
- epoch:74, test_cost:0.322571, test_acc:0.896089
- epoch:75, batch:0, cost:0.275930, acc:0.921875
- epoch:75, batch:100, cost:0.263152, acc:0.914062
- epoch:75, batch:200, cost:0.296550, acc:0.898438
- epoch:75, batch:300, cost:0.402121, acc:0.898438
- epoch:75, test_cost:0.320611, test_acc:0.897134
- epoch:76, batch:0, cost:0.279775, acc:0.921875
- epoch:76, batch:100, cost:0.439274, acc:0.843750
- epoch:76, batch:200, cost:0.330266, acc:0.898438
- epoch:76, batch:300, cost:0.418308, acc:0.851562
- epoch:76, test_cost:0.320242, test_acc:0.900429
- epoch:77, batch:0, cost:0.320668, acc:0.890625
- epoch:77, batch:100, cost:0.168939, acc:0.960938
- epoch:77, batch:200, cost:0.244379, acc:0.953125
- epoch:77, batch:300, cost:0.621534, acc:0.875000
- epoch:77, test_cost:0.319756, test_acc:0.900865
- epoch:78, batch:0, cost:0.284392, acc:0.914062
- epoch:78, batch:100, cost:0.309243, acc:0.890625
- epoch:78, batch:200, cost:0.273962, acc:0.945312
- epoch:78, batch:300, cost:0.311928, acc:0.906250
- epoch:78, test_cost:0.318491, test_acc:0.901818
- epoch:79, batch:0, cost:0.242170, acc:0.898438
- epoch:79, batch:100, cost:0.315753, acc:0.875000
- epoch:79, batch:200, cost:0.252874, acc:0.937500
- epoch:79, batch:300, cost:0.447730, acc:0.812500
- epoch:79, test_cost:0.318828, test_acc:0.900603
- Model saved successfully.
[Inference and Prediction]
- model_save_dir = "model/"
-
- def get_data(sentence):  # encode a sentence using the dictionary
-     with open(dict_file_path, "r", encoding="utf-8") as f:
-         dict_txt = eval(f.readlines()[0])
-
-     ret = []  # encoded result
-     keys = dict_txt.keys()
-     for w in sentence:  # take each character
-         if not w in keys:  # character not in the dictionary
-             w = "<unk>"
-         ret.append(int(dict_txt[w]))
-     return ret
-
- # Executor
- place = fluid.CPUPlace()
- exe = fluid.Executor(place)
- exe.run(fluid.default_startup_program())
-
- infer_program, feed_names, target_var = \
-     fluid.io.load_inference_model(model_save_dir, exe)
-
- texts = []  # sentences to classify (Chinese news headlines)
-
- data1 = get_data("在获得诺贝尔文学奖7年之后,莫言15日晚间在山西汾阳贾家庄如是说")
- data2 = get_data("综合'今日美国'、《世界日报》等当地媒体报道,芝加哥河滨警察局表示")
- data3 = get_data("中国队2022年冬奥会表现优秀")
- data4 = get_data("中国人民银行今日发布通知,降低准备金率,预计释放4000亿流动性")
- data5 = get_data("10月20日,第六届世界互联网大会正式开幕")
- data6 = get_data("同一户型,为什么高层比低层要贵那么多?")
- data7 = get_data("揭秘A股周涨5%资金动向:追捧2类股,抛售600亿香饽饽")
- data8 = get_data("宋慧乔陷入感染危机,前夫宋仲基不戴口罩露面,身处国外神态轻松")
- data9 = get_data("此盆栽花很好养,花美似牡丹,三季开花,南北都能养,很值得栽培")  # belongs to none of the ten classes
-
- texts.append(data1)
- texts.append(data2)
- texts.append(data3)
- texts.append(data4)
- texts.append(data5)
- texts.append(data6)
- texts.append(data7)
- texts.append(data8)
- texts.append(data9)
-
- base_shape = [[len(c) for c in texts]]  # length of each sentence
- tensor_words = fluid.create_lod_tensor(texts, base_shape, place)
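- # Note: create_lod_tensor flattens the variable-length code lists into a
- # single tensor, and base_shape records each sentence's length so the
- # sequence layers can recover the sentence boundaries.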
- result = exe.run(infer_program,
-                  feed={feed_names[0]: tensor_words},
-                  fetch_list=target_var)
- names = ["Culture", "Entertainment", "Sports", "Finance", "Real Estate",
-          "Automotive", "Education", "Technology", "International", "Securities"]
- for r in result[0]:
-     idx = np.argmax(r)  # index of the highest probability
-     print("Prediction:", names[idx], " probability:", r[idx])
Output:
- Prediction: Finance  probability: 0.81440145
- Prediction: Entertainment  probability: 1.0
- Prediction: Finance  probability: 1.0
- Prediction: Automotive  probability: 0.9996093
- Prediction: Culture  probability: 0.9404757
- Prediction: Entertainment  probability: 0.8715788
- Prediction: Real Estate  probability: 0.9625704
- Prediction: Technology  probability: 0.985617
- Prediction: Real Estate  probability: 1.0
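Note that the last sentence (the potted-plant headline) belongs to none of the ten training categories, yet the model still assigns it to Real Estate with probability 1.0, so softmax confidence alone does not flag out-of-domain inputs. A small inspection sketch (the top_k helper is hypothetical; it reuses names and result from the code above):
- def top_k(probs, k=3):  # hypothetical helper: top-k classes by probability
-     idxs = np.argsort(probs)[::-1][:k]
-     return [(names[i], float(probs[i])) for i in idxs]
-
- for r in result[0]:
-     print(top_k(r))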
Dataset used in this article: news_classify_data.zip (2.7 MB, Lanzou cloud): https://wwt.lanzoum.com/iTF6N1q1rnti