当前位置:   article > 正文

K邻近算法实现短信文本分类_twitter_w2v.train(x_train,total_examples=1, epochs

twitter_w2v.train(x_train,total_examples=1, epochs=1)

首先,导入需要的库

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time : 2022/3/30 21:51
  4. #导包
  5. import warnings
  6. from pandas import DataFrame, concat
  7. import pandas as pd
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. import jieba
  11. plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
  12. plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
  13. warnings.filterwarnings('ignore') #忽视警告

然后定义导入数据的函数。

  1. def get_data(data_file): #定义一个读取文件的函数
  2. """
  3. 解析文本数据
  4. :param data_file: 数据文件
  5. :return: 分词结果,标记
  6. """
  7. target = []
  8. data = []
  9. with open(data_file, 'r', encoding='utf-8') as f: #打开文件
  10. for line in f.readlines(): #按行读取
  11. line = line.strip().split("\t") #滤除行首行尾空格,以\t作为分隔符,对这行进行分解
  12. if len(line) == 1:
  13. continue
  14. target.append(int(line[0])) #\t之前的是标签
  15. data.append(line[1]) #\t之后的是文本数据
  16. # data = list(map(jieba.lcut, data))
  17. # data = [" ".join(d) for d in data]
  18. return data, target #返回文本和标签列表
  19. # 导入数据
  20. print('正在导入数据!')
  21. DATAFILE = "data.txt"
  22. data, target = get_data(DATAFILE) #读取数据
  23. data = DataFrame(data) #将文本转换化成表格形式
  24. target = DataFrame(target) #将标签转化成表格形式
  25. df = concat([data, target], axis=1) #拼接为一个表格
  26. df.columns = ['text', 'label']
  27. print('数据导入完成!显示前5行数据:')
  28. print(df.head(5))

对数据进行预处理。预处理内容包括用jieba库将句子分词处理成单词,然后再将单词转换为词向量,进而将每句话转换成了一个向量,便于输入网络进行训练。

  1. # 数据预处理
  2. print('开始数据预处理')
  3. print('1、按标签打乱训练数据!')
  4. df = df.sample(frac=1) # 将正面文本数据与负面文本数据进行打乱
  5. print('数据已按标签打乱!显示前5行:')
  6. print(df.head(5)) #显示前5行
  7. # 分词处理
  8. import jieba
  9. word_cut = lambda x: jieba.lcut(x)
  10. print('2、分词处理!')
  11. df['words'] = df["text"].apply(word_cut) #分词处理
  12. print('分词已完成!显示前5行:')
  13. print(df.head())
  14. # 去除停用词
  15. with open("hit_stopwords.txt", "r", encoding='utf-8') as f:
  16. stopwords = f.readlines()
  17. stopwords_list = []
  18. for each in stopwords:
  19. stopwords_list.append(each.strip('\n'))
  20. # 添加自定义停用词
  21. stopwords_list += ["…", "也", ".", "都", "是", "而", "了"," "]
  22. def remove_stopwords(ls): # 去除停用词
  23. return [word for word in ls if word not in stopwords_list]
  24. print('3、去除停用词处理!')
  25. df['去除停用词后的数据'] = df["words"].apply(lambda x: remove_stopwords(x))
  26. print('去除停用词完成!显示前5行:')
  27. print(df.head(5))
  28. # 词词向量处理
  29. from gensim.models.word2vec import Word2Vec
  30. x = df["去除停用词后的数据"] #处理对象是去除停用词后的数据
  31. # 训练 Word2Vec 浅层神经网络模型
  32. w2v = Word2Vec(vector_size=300, # 是指特征向量的维度,默认为100。
  33. min_count=10) # 可以对字典做截断. 词频少于min_count次数的单词会被丢弃掉, 默认值为5。
  34. w2v.build_vocab(x)
  35. w2v.train(x,
  36. total_examples=w2v.corpus_count,
  37. epochs=20)
  38. # 保存 Word2Vec 模型及词向量
  39. w2v.save('w2v_model.pkl')
  40. # 将文本转化为向量
  41. def average_vec(text):
  42. vec = np.zeros(300).reshape((1, 300)) #每个单词对应一个1行300列的向量
  43. for word in text:
  44. try:
  45. vec += w2v.wv[word].reshape((1, 300)) #每句话的向量等于每个单词的向量相加
  46. except KeyError:
  47. continue
  48. return vec
  49. # 将词向量保存为 Ndarray
  50. print('4、词向量处理!')
  51. x_vec = np.concatenate([average_vec(z) for z in x]) #把训练数据处理成词向量
  52. y = df['label'].values
  53. print('词向量处理完成!')
  54. print('数据预处理完成!')

接下来开始训练我们的网络,这里使用的是sklearn中的K邻近算法。

  1. #划分训练集,测试集
  2. from sklearn.model_selection import train_test_split
  3. X_train,X_test,y_train,y_test = train_test_split(x_vec,y,test_size=0.3)
  4. ####定义模型
  5. from sklearn.metrics import accuracy_score
  6. from sklearn.metrics import mean_squared_error
  7. from sklearn.neighbors import KNeighborsClassifier
  8. from sklearn.model_selection import train_test_split
  9. def train(k=5): #定义训练函数
  10. # 创建分类器
  11. clf = KNeighborsClassifier(n_neighbors=k) #k取5
  12. # 训练数据
  13. clf.fit(X_train, y_train)
  14. # 测试数据
  15. print('训练完成!')
  16. print('开始在验证集测试准确率!')
  17. predictions = clf.predict(X_test)
  18. print('Accuracy:', accuracy_score(y_test, predictions))
  19. return clf,predictions
  20. print('正在训练模型!')
  21. clf,predictions=train(k=6)

模型训练好后,我们拿来测试。

这里使用了两种测试方法,mode=1时我们使用手动输入的方式进行测试 ,这样方便演示这个模型的功能。mode=2时我们可以直接使用test.txt文件中的数据进行识别,并计算准确率和识别消耗的时间。

  1. num=10 #手动输入条数
  2. def one_pridect(text):
  3. words = word_cut(text)
  4. words = remove_stopwords(words)
  5. words = average_vec(words)
  6. result = clf.predict(words.reshape(1, -1))
  7. return result
  8. import time
  9. mode=2 #mode=1时手动输入,用于演示, mode=2时,计算100条测试样本
  10. if mode == 1:
  11. result_list=[]
  12. time_all=0
  13. correct=0
  14. for i in range(num):
  15. a=input('测试数据:')
  16. time_start = time.time() # 记录开始时间
  17. result=one_pridect(a)
  18. if result==0:
  19. print('正常语句')
  20. else:
  21. print('内含诈骗信息')
  22. time_end = time.time() # 记录结束时间
  23. time_sum = time_end - time_start # 计算的时间差为程序的执行时间,单位为秒/s
  24. print('识别用时%f' %time_sum)
  25. result_list.append(result)
  26. time_all+=time_sum
  27. error=np.sum(result_list)
  28. correct=num-error
  29. accuracy=accuracy_score(y_test, predictions)
  30. print('正常语句数量:%d' %correct)
  31. print('内含诈骗信息语句数量:%d' % error)
  32. print('总用时:%f' %time_all)
  33. elif mode==2 :
  34. TESTFILE="test.txt"
  35. data, target = get_data(TESTFILE) # 读取数据
  36. data = DataFrame(data) # 将文本转换化成表格形式
  37. target = DataFrame(target) # 将标签转化成表格形式
  38. df = concat([data, target], axis=1) # 拼接为一个表格
  39. df.columns = ['text', 'label']
  40. df['words'] = df["text"].apply(word_cut) # 分词处理
  41. df['去除停用词后的数据'] = df["words"].apply(lambda x: remove_stopwords(x))
  42. x = df["去除停用词后的数据"] # 处理对象是去除停用词后的数据
  43. x_vec = np.concatenate([average_vec(z) for z in x]) # 把训练数据处理成词向量
  44. y = df['label'].values
  45. time_start = time.time() # 记录开始时间
  46. predictions = clf.predict(x_vec)
  47. time_end = time.time() # 记录结束时间
  48. time_sum = time_end - time_start # 计算的时间差为程序的执行时间,单位为秒/s
  49. print('识别100条数据用时%f' % time_sum)
  50. print('Accuracy:', accuracy_score(y, predictions)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/435613
推荐阅读
  

闽ICP备14008679号