
NLP: Hands-On LSTM Text Classification

The walkthrough below covers four source files: utils_fasttext.py and utils.py (vocabulary building, dataset loading, and batch iteration), train_eval.py (training, evaluation, and early stopping), and run.py (the entry point that ties everything together).
# utils_fasttext.py
# coding: UTF-8
import os
import torch
import numpy as np
import pickle as pkl
from tqdm import tqdm
import time
from datetime import timedelta


MAX_VOCAB_SIZE = 10000  # cap on vocabulary size
UNK, PAD = '<UNK>', '<PAD>'  # unknown token, padding token


def build_vocab(file_path, tokenizer, max_size, min_freq):
    vocab_dic = {}
    with open(file_path, 'r', encoding='UTF-8') as f:
        for line in tqdm(f):
            lin = line.strip()
            if not lin:
                continue
            content = lin.split('\t')[0]
            for word in tokenizer(content):
                vocab_dic[word] = vocab_dic.get(word, 0) + 1
        vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq],
                            key=lambda x: x[1], reverse=True)[:max_size]
        vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
        vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
    return vocab_dic


def build_dataset(config, ues_word):
    if ues_word:
        tokenizer = lambda x: x.split(' ')  # word-level: tokens separated by spaces
    else:
        tokenizer = lambda x: [y for y in x]  # char-level
    if os.path.exists(config.vocab_path):
        vocab = pkl.load(open(config.vocab_path, 'rb'))
    else:
        vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
        pkl.dump(vocab, open(config.vocab_path, 'wb'))
    print(f"Vocab size: {len(vocab)}")

    def biGramHash(sequence, t, buckets):
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        return (t1 * 14918087) % buckets

    def triGramHash(sequence, t, buckets):
        t1 = sequence[t - 1] if t - 1 >= 0 else 0
        t2 = sequence[t - 2] if t - 2 >= 0 else 0
        return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets

    def load_dataset(path, pad_size=32):
        contents = []
        with open(path, 'r', encoding='UTF-8') as f:
            for line in tqdm(f):
                lin = line.strip()
                if not lin:
                    continue
                content, label = lin.split('\t')
                words_line = []
                token = tokenizer(content)
                seq_len = len(token)
                if pad_size:
                    if len(token) < pad_size:
                        token.extend([PAD] * (pad_size - len(token)))  # pad with the PAD token so it maps to the PAD id below
                    else:
                        token = token[:pad_size]
                        seq_len = pad_size
                # word to id
                for word in token:
                    words_line.append(vocab.get(word, vocab.get(UNK)))

                # fasttext n-gram features
                buckets = config.n_gram_vocab
                bigram = []
                trigram = []
                # ------ n-gram ------
                for i in range(pad_size):
                    bigram.append(biGramHash(words_line, i, buckets))
                    trigram.append(triGramHash(words_line, i, buckets))
                # --------------------
                contents.append((words_line, int(label), seq_len, bigram, trigram))
        return contents  # [([...], label, seq_len, bigram, trigram), ...]

    train = load_dataset(config.train_path, config.pad_size)
    dev = load_dataset(config.dev_path, config.pad_size)
    test = load_dataset(config.test_path, config.pad_size)
    return vocab, train, dev, test


class DatasetIterater(object):
    def __init__(self, batches, batch_size, device):
        self.batch_size = batch_size
        self.batches = batches
        self.n_batches = len(batches) // batch_size
        self.residue = False  # True if the last batch is smaller than batch_size
        if len(batches) % batch_size != 0:
            self.residue = True
        self.index = 0
        self.device = device

    def _to_tensor(self, datas):
        # xx = [xxx[2] for xxx in datas]
        # indexx = np.argsort(xx)[::-1]
        # datas = np.array(datas)[indexx]
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
        bigram = torch.LongTensor([_[3] for _ in datas]).to(self.device)
        trigram = torch.LongTensor([_[4] for _ in datas]).to(self.device)
        # sequence length before padding (capped at pad_size)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        return (x, seq_len, bigram, trigram), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index * self.batch_size: len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches
        elif self.index >= self.n_batches:
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

    def __iter__(self):
        return self

    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches


def build_iterator(dataset, config):
    iter = DatasetIterater(dataset, config.batch_size, config.device)
    return iter


def get_time_dif(start_time):
    """Return the elapsed time since start_time."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


if __name__ == "__main__":
    '''Extract pre-trained word vectors for the vocabulary.'''
    vocab_dir = "./THUCNews/data/vocab.pkl"
    pretrain_dir = "./THUCNews/data/sgns.sogou.char"
    emb_dim = 300
    filename_trimmed_dir = "./THUCNews/data/vocab.embedding.sougou"
    word_to_id = pkl.load(open(vocab_dir, 'rb'))
    embeddings = np.random.rand(len(word_to_id), emb_dim)
    f = open(pretrain_dir, "r", encoding='UTF-8')
    for i, line in enumerate(f.readlines()):
        # if i == 0:  # skip the header line, if there is one
        #     continue
        lin = line.strip().split(" ")
        if lin[0] in word_to_id:
            idx = word_to_id[lin[0]]
            emb = [float(x) for x in lin[1:301]]
            embeddings[idx] = np.asarray(emb, dtype='float32')
    f.close()
    np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)
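The two hash functions inside build_dataset turn each token id and its predecessors into a bucket index, which FastText uses as extra bigram/trigram features without storing an explicit n-gram vocabulary. Below is a minimal standalone sketch of that idea; the function bodies are restated from build_dataset for illustration, and the token ids and bucket count are made-up example values (in the repository the bucket count would come from config.n_gram_vocab).

# Toy demonstration of the bigram/trigram bucket hashing used above.
def biGramHash(sequence, t, buckets):
    t1 = sequence[t - 1] if t - 1 >= 0 else 0
    return (t1 * 14918087) % buckets

def triGramHash(sequence, t, buckets):
    t1 = sequence[t - 1] if t - 1 >= 0 else 0
    t2 = sequence[t - 2] if t - 2 >= 0 else 0
    return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets

words_line = [17, 4, 256, 3]   # hypothetical token ids for one padded sentence
buckets = 250499               # hypothetical bucket count (config.n_gram_vocab)
print([biGramHash(words_line, i, buckets) for i in range(len(words_line))])
print([triGramHash(words_line, i, buckets) for i in range(len(words_line))])

Each position gets one bigram bucket and one trigram bucket, so every sample carries two extra id sequences of length pad_size that index into separate embedding tables in the FastText model.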

# utils.py
# coding: UTF-8
import os
import torch
import numpy as np
import pickle as pkl
from tqdm import tqdm
import time
from datetime import timedelta


MAX_VOCAB_SIZE = 10000  # cap on vocabulary size
UNK, PAD = '<UNK>', '<PAD>'  # unknown token, padding token


def build_vocab(file_path, tokenizer, max_size, min_freq):
    vocab_dic = {}
    with open(file_path, 'r', encoding='UTF-8') as f:
        for line in tqdm(f):
            lin = line.strip()
            if not lin:
                continue
            content = lin.split('\t')[0]
            for word in tokenizer(content):
                vocab_dic[word] = vocab_dic.get(word, 0) + 1
        vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq],
                            key=lambda x: x[1], reverse=True)[:max_size]
        vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
        vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
    return vocab_dic


def build_dataset(config, ues_word):
    if ues_word:
        tokenizer = lambda x: x.split(' ')  # word-level: tokens separated by spaces
    else:
        tokenizer = lambda x: [y for y in x]  # char-level
    if os.path.exists(config.vocab_path):
        vocab = pkl.load(open(config.vocab_path, 'rb'))
    else:
        vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
        pkl.dump(vocab, open(config.vocab_path, 'wb'))
    print(f"Vocab size: {len(vocab)}")

    def load_dataset(path, pad_size=32):
        contents = []
        with open(path, 'r', encoding='UTF-8') as f:
            for line in tqdm(f):
                lin = line.strip()
                if not lin:
                    continue
                content, label = lin.split('\t')
                words_line = []
                token = tokenizer(content)
                seq_len = len(token)
                if pad_size:
                    if len(token) < pad_size:
                        token.extend([PAD] * (pad_size - len(token)))  # pad with the PAD token so it maps to the PAD id below
                    else:
                        token = token[:pad_size]
                        seq_len = pad_size
                # word to id
                for word in token:
                    words_line.append(vocab.get(word, vocab.get(UNK)))
                contents.append((words_line, int(label), seq_len))
        return contents  # [([...], label, seq_len), ...]

    train = load_dataset(config.train_path, config.pad_size)
    dev = load_dataset(config.dev_path, config.pad_size)
    test = load_dataset(config.test_path, config.pad_size)
    return vocab, train, dev, test


class DatasetIterater(object):
    def __init__(self, batches, batch_size, device):
        self.batch_size = batch_size
        self.batches = batches
        self.n_batches = len(batches) // batch_size
        self.residue = False  # True if the last batch is smaller than batch_size
        if len(batches) % batch_size != 0:
            self.residue = True
        self.index = 0
        self.device = device

    def _to_tensor(self, datas):
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
        # sequence length before padding (capped at pad_size)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        return (x, seq_len), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index * self.batch_size: len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches
        elif self.index >= self.n_batches:
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

    def __iter__(self):
        return self

    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches


def build_iterator(dataset, config):
    iter = DatasetIterater(dataset, config.batch_size, config.device)
    return iter


def get_time_dif(start_time):
    """Return the elapsed time since start_time."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


if __name__ == "__main__":
    '''Extract pre-trained word vectors for the vocabulary.'''
    # Adjust the directories and file names below as needed.
    train_dir = "./THUCNews/data/train.txt"
    vocab_dir = "./THUCNews/data/vocab.pkl"
    pretrain_dir = "./THUCNews/data/sgns.sogou.char"
    emb_dim = 300
    filename_trimmed_dir = "./THUCNews/data/embedding_SougouNews"
    if os.path.exists(vocab_dir):
        word_to_id = pkl.load(open(vocab_dir, 'rb'))
    else:
        # tokenizer = lambda x: x.split(' ')  # word-level vocab (tokens separated by spaces)
        tokenizer = lambda x: [y for y in x]  # char-level vocab
        word_to_id = build_vocab(train_dir, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
        pkl.dump(word_to_id, open(vocab_dir, 'wb'))
    embeddings = np.random.rand(len(word_to_id), emb_dim)
    f = open(pretrain_dir, "r", encoding='UTF-8')
    for i, line in enumerate(f.readlines()):
        # if i == 0:  # skip the header line, if there is one
        #     continue
        lin = line.strip().split(" ")
        if lin[0] in word_to_id:
            idx = word_to_id[lin[0]]
            emb = [float(x) for x in lin[1:301]]
            embeddings[idx] = np.asarray(emb, dtype='float32')
    f.close()
    np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)
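To see what the iterator actually yields, here is a small self-contained usage sketch. It assumes utils.py above is importable as a module; the samples are fake (ids, label, seq_len) triples, not real THUCNews data.

# Quick check of DatasetIterater with fake samples.
import torch
from utils import DatasetIterater

pad_size = 32
fake_data = [([1, 2, 3] + [0] * (pad_size - 3), lab % 2, 3) for lab in range(5)]
it = DatasetIterater(fake_data, batch_size=2, device=torch.device('cpu'))
for (x, seq_len), y in it:
    print(x.shape, seq_len.shape, y.shape)
# Expected: two full batches of 2 plus a residual batch of 1,
# with x shaped [batch, pad_size], seq_len and y shaped [batch].

Each batch is a ((x, seq_len), y) pair, which is exactly the shape the model's forward pass and F.cross_entropy expect in train_eval.py below.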

# train_eval.py
# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from tensorboardX import SummaryWriter


# Weight initialization, Xavier by default
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        if exclude not in name:
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass


def train(config, model, train_iter, dev_iter, test_iter, writer):
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Exponential learning-rate decay: each epoch, lr = gamma * lr
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index at which the dev loss last improved
    flag = False  # whether training has stalled for too long
    # writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        # scheduler.step()  # learning-rate decay
        for i, (trains, labels) in enumerate(train_iter):
            # print(trains[0].shape)
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # Every 100 batches, report metrics on the current training batch and the dev set
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # Early stopping: the dev loss has not improved for config.require_improvement batches
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    writer.close()
    test(config, model, test_iter)


def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)
    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)
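One detail worth noting in init_network is the exclude='embedding' filter: any parameter whose name contains the string 'embedding' is left untouched, so pre-trained embedding weights are not overwritten at startup. The sketch below illustrates the name matching with a hypothetical toy model (module names chosen only for this demo), assuming train_eval.py and its dependencies are importable.

# Which parameters does init_network re-initialize? A toy check.
import torch.nn as nn
from train_eval import init_network

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.embedding = nn.Embedding(10, 4)         # name contains 'embedding' -> skipped
        self.lstm = nn.LSTM(4, 8, batch_first=True)  # weights get Xavier init, biases get 0
        self.fc = nn.Linear(8, 2)

model = ToyModel()
init_network(model)
for name, _ in model.named_parameters():
    print(name, '-> skipped' if 'embedding' in name else '-> re-initialized')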

# run.py
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # dataset directory

    # Sogou News: embedding_SougouNews.npz, Tencent: embedding_Tencent.npz, random init: random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  # TextCNN, TextRNN, ...
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # make results reproducible across runs

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter, writer)
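run.py expects each model under models/ to expose a Config class and a Model class, but no model file is shown in this post. Below is a minimal, hedged sketch of what a models/TextRNN.py (the LSTM classifier of the title) could look like so that it fits the interface used above; the hyperparameters, file names under data/, and layer sizes are illustrative assumptions, not definitive values.

# models/TextRNN.py -- minimal sketch of an LSTM text classifier matching run.py's interface
import torch
import torch.nn as nn
import numpy as np


class Config(object):
    """Everything utils.py, train_eval.py and run.py read from `config` (illustrative values)."""
    def __init__(self, dataset, embedding):
        self.model_name = 'TextRNN'
        self.train_path = dataset + '/data/train.txt'
        self.dev_path = dataset + '/data/dev.txt'
        self.test_path = dataset + '/data/test.txt'
        self.class_list = [x.strip() for x in open(dataset + '/data/class.txt', encoding='utf-8').readlines()]
        self.vocab_path = dataset + '/data/vocab.pkl'
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
        self.log_path = dataset + '/log/' + self.model_name
        self.embedding_pretrained = torch.tensor(
            np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32')) \
            if embedding != 'random' else None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.dropout = 0.5                 # illustrative hyperparameters
        self.require_improvement = 1000    # stop if the dev loss does not improve for this many batches
        self.num_classes = len(self.class_list)
        self.n_vocab = 0                   # set in run.py after the vocab is built
        self.num_epochs = 10
        self.batch_size = 128
        self.pad_size = 32
        self.learning_rate = 1e-3
        self.embed = self.embedding_pretrained.size(1) if self.embedding_pretrained is not None else 300
        self.hidden_size = 128
        self.num_layers = 2


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):
        x, _ = x                      # (token ids, seq_len) as produced by DatasetIterater
        out = self.embedding(x)       # [batch, pad_size, embed]
        out, _ = self.lstm(out)       # [batch, pad_size, 2 * hidden_size]
        out = self.fc(out[:, -1, :])  # classify on the last time step's hidden state
        return out

With a file like this saved under models/, training would typically be launched with: python run.py --model TextRNN (add --embedding random to skip the pre-trained vectors).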
