
Week N6: Text Classification with Word2vec

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

# Ignore warning messages
warnings.filterwarnings("ignore")

# Windows 10 system
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

import pandas as pd

# Load the custom Chinese dataset
train_data = pd.read_csv('./data/train2.csv', sep='\t', header=None)
train_data.head()

# Build a dataset iterator
def coustom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y

x = train_data[0].values[:]
# Multi-class labels
y = train_data[1].values[:]

from gensim.models.word2vec import Word2Vec
import numpy as np

# Train the Word2Vec shallow neural-network model
w2v = Word2Vec(vector_size=100,  # dimensionality of the feature vectors; default is 100
               min_count=3)      # prune the vocabulary: tokens occurring fewer than min_count times are dropped; default is 5
w2v.build_vocab(x)
w2v.train(x, total_examples=w2v.corpus_count, epochs=20)
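
# Note: x is an array of raw strings, so gensim iterates over each sentence
# character by character here -- the learned vocabulary consists of individual
# Chinese characters rather than segmented words. Optional, purely illustrative check:
print(len(w2v.wv.key_to_index))  # number of characters kept after min_count pruning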
# Convert a text into a vector (sum of its in-vocabulary character vectors)
def average_vec(text):
    vec = np.zeros(100).reshape((1, 100))
    for word in text:
        try:
            vec += w2v.wv[word].reshape((1, 100))
        except KeyError:
            continue
    return vec

# Stack the per-text vectors into an ndarray
x_vec = np.concatenate([average_vec(z) for z in x])

# Save the Word2Vec model and word vectors
w2v.save('data/w2v_model.pkl')

train_iter = coustom_data_iter(x_vec, y)
len(x), len(x_vec)

label_name = list(set(train_data[1].values[:]))
print(label_name)

text_pipeline = lambda x: average_vec(x)
label_pipeline = lambda x: label_name.index(x)

text_pipeline("你在干嘛")
label_pipeline("Travel-Query")
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list = [], []
    for (_text, _label) in batch:
        # Label list
        label_list.append(label_pipeline(_label))
        # Text list
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)
        text_list.append(processed_text)
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)
    return text_list.to(device), label_list.to(device)

# Data loader, example call
dataloader = DataLoader(train_iter, batch_size=8,
                        shuffle=False,
                        collate_fn=collate_batch)
from torch import nn

class TextclassificationModel(nn.Module):
    def __init__(self, num_class):
        super(TextclassificationModel, self).__init__()
        self.fc = nn.Linear(100, num_class)

    def forward(self, text):
        return self.fc(text)

num_class = len(label_name)
vocab_size = 100000
em_size = 12
model = TextclassificationModel(num_class).to(device)
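
# Optional sanity check (illustrative only; dummy_input is not part of the original script):
# the classifier is a single linear layer mapping a 100-dim text vector to num_class logits.
dummy_input = torch.randn(4, 100).to(device)
print(model(dummy_input).shape)  # expected: torch.Size([4, num_class])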
import time

def train(dataloader):
    model.train()  # switch to training mode
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 50
    start_time = time.time()

    for idx, (text, label) in enumerate(dataloader):
        predicted_label = model(text)

        optimizer.zero_grad()                     # zero the gradients
        loss = criterion(predicted_label, label)  # loss between network output and ground-truth label
        loss.backward()                           # backpropagation
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # gradient clipping
        optimizer.step()                          # update the parameters

        # Record accuracy and loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                                                  total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()  # switch to evaluation mode
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (text, label) in enumerate(dataloader):
            predicted_label = model(text)
            loss = criterion(predicted_label, label)  # compute the loss
            # Record evaluation statistics
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)

    return total_acc/total_count, train_loss/total_count
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# Hyperparameters
EPOCHS = 10       # number of epochs
LR = 5            # learning rate
BATCH_SIZE = 64   # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

# Build the datasets
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,
                                          [int(len(train_dataset)*0.8), int(len(train_dataset)*0.2)])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    # Get the current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc

    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,
                                                                     time.time() - epoch_start_time,
                                                                     val_acc, val_loss, lr))
    print('-' * 69)
# test_acc, test_loss = evaluate(valid_dataloader)
# print('模型准确率为: {:5.4f}'.format(test_acc))
#
# def predict(text, text_pipeline):
#     with torch.no_grad():
#         text = torch.tensor(text_pipeline(text), dtype=torch.float32)
#         print(text.shape)
#         output = model(text)
#         return output.argmax(1).item()
#
# # ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
# ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
# model = model.to("cpu")
# print("该文本的类别是: %s" % label_name[predict(ex_text_str, text_pipeline)])

The above is the basic code for the text-classification task.
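
The commented-out block at the end of the listing hints at single-sentence inference. Below is a minimal sketch of that step which also reloads the Word2Vec model saved earlier; it assumes training has just finished in the same session (so model, label_name and text_pipeline are still defined) and that data/w2v_model.pkl exists:

from gensim.models.word2vec import Word2Vec

# Reload the saved Word2Vec model; average_vec/text_pipeline read the global w2v,
# so the embeddings do not need to be retrained for inference.
w2v = Word2Vec.load('data/w2v_model.pkl')

def predict(text, text_pipeline):
    with torch.no_grad():
        vec = torch.tensor(text_pipeline(text), dtype=torch.float32)
        output = model(vec)
        return output.argmax(1).item()

model = model.to("cpu")
ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
print("Predicted category: %s" % label_name[predict(ex_text_str, text_pipeline)])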

Output:

[[-0.85472693 0.96605204 1.5058695 -0.06065784 -2.10079319 -0.12021151
1.41170089 2.00004494 0.90861696 -0.62710127 -0.62408304 -3.80595499
1.02797993 -0.45584389 0.54715634 1.70490362 2.33389823 -1.99607518
4.34822938 -0.76296186 2.73265275 -1.15046433 0.82106878 -0.32701646
-0.50515595 -0.37742117 -2.02331601 -1.365334 1.48786476 -1.6394971
1.59438308 2.23569647 -0.00500725 -0.65070192 0.07377997 0.01777986
-1.35580809 3.82080549 -2.19764423 1.06595343 0.99296588 0.58972518
-0.33535255 2.15471306 -0.52244038 1.00874437 1.28869729 -0.72208139
-2.81094289 2.2614549 0.20799019 -2.36187895 -0.94019454 0.49448857
-0.68613767 -0.79071895 0.47535057 -0.78339124 -0.71336574 -0.27931567
1.0514895 -1.76352624 1.93158554 -0.85853558 -0.65540617 1.3612217
-1.39405773 1.18187538 1.31730198 -0.02322496 0.14652854 0.22249881
2.01789951 -0.40144247 -0.39880068 -0.16220299 -2.85221207 -0.27722868
2.48236791 -0.51239379 -1.47679498 -0.28452797 -2.64497767 2.12093259
-1.2326943 -1.89571355 2.3295732 -0.53244872 -0.67313893 -0.80814604
0.86987564 -1.31373079 1.33797717 1.02223087 0.5817025 -0.83535647
0.97088164 2.09045361 -2.57758138 0.07126901]]
6

The first block above is the 100-dimensional vector returned by text_pipeline("你在干嘛"), and the trailing 6 is the value returned by label_pipeline("Travel-Query"), i.e. the index of that label in label_name. Note that the vector is not all zeros, which means the input characters were found in the Word2Vec vocabulary.
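
Conversely, if none of the input characters appear in the trained vocabulary, every lookup falls into the KeyError branch of average_vec and the result is an all-zero vector. A small illustrative check (the string below is only a hypothetical out-of-vocabulary example; whether it is truly absent depends on the training corpus):

print(average_vec("qwerty"))  # expected: an all-zero (1, 100) array if every character is out of vocabulary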
