
Text classification with CNN and LSTM in PyTorch

The code below implements character-level Chinese text classification on the cnews dataset, with both an RNN model (LSTM or GRU) and a CNN model. It consists of four files: model.py, train.py, test.py, and cnews_loader.py.
model.py:
#!/usr/bin/python
# -*- coding: utf-8 -*-
import torch
from torch import nn
import torch.nn.functional as F


class TextRNN(nn.Module):
    """Text classification, RNN model."""

    def __init__(self):
        super(TextRNN, self).__init__()
        self.embedding = nn.Embedding(5000, 64)  # character embedding
        # self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=2,
        #                    bidirectional=True, batch_first=True)
        self.rnn = nn.GRU(input_size=64, hidden_size=128, num_layers=2,
                          bidirectional=True, batch_first=True)
        self.f1 = nn.Sequential(nn.Linear(256, 128),
                                nn.Dropout(0.8),
                                nn.ReLU())
        self.f2 = nn.Sequential(nn.Linear(128, 10),
                                nn.Softmax(dim=1))

    def forward(self, x):
        x = self.embedding(x)              # (batch, seq_len, 64)
        x, _ = self.rnn(x)                 # (batch, seq_len, 2 * 128) for the bidirectional GRU
        x = F.dropout(x, p=0.8, training=self.training)
        x = self.f1(x[:, -1, :])           # hidden state at the last time step
        return self.f2(x)


class TextCNN(nn.Module):
    """Text classification, CNN model."""

    def __init__(self):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(5000, 64)
        self.conv = nn.Conv1d(64, 256, 5)
        # 596 = 600 (padded sequence length) - 5 (kernel size) + 1
        self.f1 = nn.Sequential(nn.Linear(256 * 596, 128),
                                nn.ReLU())
        self.f2 = nn.Sequential(nn.Linear(128, 10),
                                nn.Softmax(dim=1))

    def forward(self, x):
        x = self.embedding(x)      # (batch, seq_len, 64)
        x = x.permute(0, 2, 1)     # Conv1d expects (batch, channels, seq_len); permute keeps autograd intact
        x = self.conv(x)           # (batch, 256, 596)
        x = x.view(-1, 256 * 596)
        x = self.f1(x)
        return self.f2(x)
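A quick shape check makes the hard-coded 256 * 596 visible: with inputs padded to length 600, a Conv1d with kernel size 5 and no padding yields 600 - 5 + 1 = 596 positions. A minimal sketch (the dummy batch below is illustrative, not part of the original code):

import torch
from model import TextRNN, TextCNN

# Dummy batch: 2 sequences of 600 character ids from the 5000-entry vocabulary
x = torch.randint(0, 5000, (2, 600))
print(TextCNN()(x).shape)  # torch.Size([2, 10]) -- one probability per category
print(TextRNN()(x).shape)  # torch.Size([2, 10])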
train.py:
# coding: utf-8
from __future__ import print_function
import os

import numpy as np
import torch
from torch import nn
from torch import optim

from model import TextRNN, TextCNN
from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

base_dir = 'cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')


def train():
    # Character ids for each text and one-hot labels
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, 600)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, 600)

    # Use the RNN (LSTM/GRU) model or the CNN model
    model = TextRNN()
    # model = TextCNN()

    # Choose a loss function. Note that MultiLabelSoftMarginLoss applies a
    # sigmoid internally; with one-hot single-label targets, CrossEntropyLoss
    # on raw logits would be the more conventional choice.
    Loss = nn.MultiLabelSoftMarginLoss()
    # Loss = nn.BCELoss()
    # Loss = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    best_val_acc = 0
    for epoch in range(1000):
        model.train()
        for x_batch, y_batch in batch_iter(x_train, y_train, 100):
            x = torch.LongTensor(np.array(x_batch))
            y = torch.Tensor(np.array(y_batch))
            out = model(x)
            loss = Loss(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            accuracy = np.mean((torch.argmax(out, 1) == torch.argmax(y, 1)).numpy())
        # Validate the model every 20 epochs, on the validation set and
        # without gradient updates
        if (epoch + 1) % 20 == 0:
            model.eval()
            with torch.no_grad():
                for x_batch, y_batch in batch_iter(x_val, y_val, 100):
                    x = torch.LongTensor(np.array(x_batch))
                    y = torch.Tensor(np.array(y_batch))
                    out = model(x)
                    accuracy = np.mean((torch.argmax(out, 1) == torch.argmax(y, 1)).numpy())
                    if accuracy > best_val_acc:
                        torch.save(model.state_dict(), 'model_params.pkl')
                        best_val_acc = accuracy
                    print(accuracy)


if __name__ == '__main__':
    # Categories and the category -> id mapping
    categories, cat_to_id = read_category()
    # Every character in the training vocabulary and its id
    words, word_to_id = read_vocab(vocab_dir)
    # Vocabulary size
    vocab_size = len(words)
    train()
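train.py defines test_dir but never evaluates on it. A minimal sketch of a final test-set pass with the saved weights (assuming model_params.pkl was produced by train() with TextRNN; feeding the test set in chunks would be kinder to memory, but is omitted for brevity):

import numpy as np
import torch
from model import TextRNN
from cnews_loader import read_vocab, read_category, process_file

# Rebuild the same lookups used during training
categories, cat_to_id = read_category()
words, word_to_id = read_vocab('cnews/cnews.vocab.txt')

model = TextRNN()
model.load_state_dict(torch.load('model_params.pkl'))
model.eval()

x_test, y_test = process_file('cnews/cnews.test.txt', word_to_id, cat_to_id, 600)
with torch.no_grad():
    out = model(torch.LongTensor(x_test))
accuracy = np.mean((torch.argmax(out, 1) == torch.argmax(torch.Tensor(y_test), 1)).numpy())
print('test accuracy:', accuracy)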
test.py:
# coding: utf-8
from __future__ import print_function
import os

import tensorflow.contrib.keras as kr
import torch
from torch import nn

from cnews_loader import read_category, read_vocab
from model import TextRNN

try:
    bool(type(unicode))
except NameError:
    unicode = str

base_dir = 'cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')


class TextCNN(nn.Module):
    """Same architecture as model.TextCNN; 152576 = 256 * 596."""

    def __init__(self):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(5000, 64)
        self.conv = nn.Conv1d(64, 256, 5)
        self.f1 = nn.Sequential(nn.Linear(152576, 128),
                                nn.ReLU())
        self.f2 = nn.Sequential(nn.Linear(128, 10),
                                nn.Softmax(dim=1))

    def forward(self, x):
        x = self.embedding(x)      # (batch, seq_len, 64)
        x = x.permute(0, 2, 1)     # Conv1d expects (batch, channels, seq_len)
        x = self.conv(x)
        x = x.view(-1, 152576)
        x = self.f1(x)
        return self.f2(x)


class CnnModel:
    def __init__(self):
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.model = TextCNN()
        self.model.load_state_dict(torch.load('model_params.pkl'))
        self.model.eval()  # disable dropout for inference

    def predict(self, message):
        # A model trained under Python 2 or Python 3 can run under either version
        content = unicode(message)
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]
        data = kr.preprocessing.sequence.pad_sequences([data], 600)
        data = torch.LongTensor(data)
        y_pred_cls = self.model(data)
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]


class RnnModel:
    def __init__(self):
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.model = TextRNN()
        self.model.load_state_dict(torch.load('model_rnn_params.pkl'))
        self.model.eval()  # disable dropout for inference

    def predict(self, message):
        # A model trained under Python 2 or Python 3 can run under either version
        content = unicode(message)
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]
        data = kr.preprocessing.sequence.pad_sequences([data], 600)
        data = torch.LongTensor(data)
        y_pred_cls = self.model(data)
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]


if __name__ == '__main__':
    model = CnnModel()
    # model = RnnModel()
    test_demo = ['湖人助教力助科比恢复手感 他也是阿泰的精神导师新浪体育讯记者戴高乐报道  上赛季,科比的右手食指遭遇重创,他的投篮手感也因此大受影响。不过很快科比就调整了自己的投篮手型,并通过这一方式让自己的投篮命中率回升。而在这科比背后,有一位特别助教对科比帮助很大,他就是查克·珀森。珀森上赛季担任湖人的特别助教,除了帮助科比调整投篮手型之外,他的另一个重要任务就是担任阿泰的精神导师。来到湖人队之后,阿泰收敛起了暴躁的脾气,成为湖人夺冠路上不可或缺的一员,珀森的“心灵按摩”功不可没。经历了上赛季的成功之后,珀森本赛季被“升职”成为湖人队的全职助教,每场比赛,他都会坐在球场边,帮助禅师杰克逊一起指挥湖人球员在场上拼杀。对于珀森的工作,禅师非常欣赏,“查克非常善于分析问题,”菲尔·杰克逊说,“他总是在寻找问题的答案,同时也在找造成这一问题的原因,这是我们都非常乐于看到的。我会在平时把防守中出现的一些问题交给他,然后他会通过组织球员练习找到解决的办法。他在球员时代曾是一名很好的外线投手,不过现在他与内线球员的配合也相当不错。',
                 '弗老大被裁美国媒体看热闹“特权”在中国像蠢蛋弗老大要走了。虽然他只在首钢男篮效力了13天,而且表现毫无亮点,大大地让球迷和俱乐部失望了,但就像中国人常说的“好聚好散”,队友还是友好地与他告别,俱乐部与他和平分手,球迷还请他留下了在北京的最后一次签名。相比之下,弗老大的同胞美国人却没那么“宽容”。他们嘲讽这位NBA前巨星的英雄迟暮,批评他在CBA的业余表现,还惊讶于中国人的“大方”。今天,北京首钢俱乐部将与弗朗西斯继续商讨解约一事。从昨日的进展来看,双方可以做到“买卖不成人意在”,但回到美国后,恐怕等待弗朗西斯的就没有这么轻松的环境了。进展@北京昨日与队友告别  最后一次为球迷签名弗朗西斯在13天里为首钢队打了4场比赛,3场的得分为0,只有一场得了2分。昨天是他来到北京的第14天,虽然他与首钢还未正式解约,但双方都明白“缘分已尽”。下午,弗朗西斯来到首钢俱乐部与队友们告别。弗朗西斯走到队友身边,依次与他们握手拥抱。“你们都对我很好,安排的条件也很好,我很喜欢这支球队,想融入你们,但我现在真的很不适应。希望你们']
    for i in test_demo:
        print(i, ":", model.predict(i))
cnews_loader.py:
# coding: utf-8
import sys
from collections import Counter

import numpy as np
import tensorflow.contrib.keras as kr

if sys.version_info[0] > 2:
    is_py3 = True
else:
    reload(sys)
    sys.setdefaultencoding("utf-8")
    is_py3 = False


def native_word(word, encoding='utf-8'):
    """Under Python 2, convert the encoding of words from a model trained under Python 3."""
    if not is_py3:
        return word.encode(encoding)
    else:
        return word


def native_content(content):
    if not is_py3:
        return content.decode('utf-8')
    else:
        return content


def open_file(filename, mode='r'):
    """
    Common file helper that works under both Python 2 and Python 3.
    mode: 'r' or 'w' for read or write
    """
    if is_py3:
        return open(filename, mode, encoding='utf-8', errors='ignore')
    else:
        return open(filename, mode)


def read_file(filename):
    """Read a data file: one tab-separated (label, content) pair per line."""
    contents, labels = [], []
    with open_file(filename) as f:
        for line in f:
            try:
                label, content = line.strip().split('\t')
                if content:
                    contents.append(list(native_content(content)))  # character-level split
                    labels.append(native_content(label))
            except ValueError:
                pass  # skip malformed lines
    return contents, labels


def build_vocab(train_dir, vocab_dir, vocab_size=5000):
    """Build the vocabulary from the training set and store it."""
    data_train, _ = read_file(train_dir)
    all_data = []
    for content in data_train:
        all_data.extend(content)
    counter = Counter(all_data)
    count_pairs = counter.most_common(vocab_size - 1)
    words, _ = list(zip(*count_pairs))
    # Add a <PAD> token so that all texts can be padded to the same length
    words = ['<PAD>'] + list(words)
    with open_file(vocab_dir, mode='w') as f:
        f.write('\n'.join(words) + '\n')


def read_vocab(vocab_dir):
    """Read the vocabulary file."""
    with open_file(vocab_dir) as fp:
        # Under Python 2, convert every entry to unicode
        words = [native_content(_.strip()) for _ in fp.readlines()]
    word_to_id = dict(zip(words, range(len(words))))
    return words, word_to_id


def read_category():
    """Return the fixed list of categories and their ids."""
    categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
    categories = [native_content(x) for x in categories]
    cat_to_id = dict(zip(categories, range(len(categories))))
    return categories, cat_to_id


def to_words(content, words):
    """Convert id-encoded content back to text."""
    return ''.join(words[x] for x in content)


def process_file(filename, word_to_id, cat_to_id, max_length=600):
    """Convert a file to id representation."""
    contents, labels = read_file(filename)  # each text and its category
    data_id, label_id = [], []
    for i in range(len(contents)):
        data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])  # ids for each text
        label_id.append(cat_to_id[labels[i]])  # id of each text's category

    # Use keras' pad_sequences to pad every text to a fixed length
    x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
    # Convert labels to one-hot representation
    y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))
    return x_pad, y_pad


def batch_iter(x, y, batch_size=64):
    """Generate batches of data in shuffled order."""
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1
    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]
    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
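Note that build_vocab is imported in train.py but never called, so cnews.vocab.txt must already exist. If it does not, generate it once from the training set:

from cnews_loader import build_vocab

# One-off: write the 4999 most frequent characters plus <PAD> to the vocab file
build_vocab('cnews/cnews.train.txt', 'cnews/cnews.vocab.txt', vocab_size=5000)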

 
