Bag of Tricks for Efficient Text Classification(Fasttext)

三、Model architecture








     1)  Fasttext提取的是句子特征,CBOW提取的是上下文特征 



      1)  当类别非常多的时候,最后的softmax速度比较慢(因为要构造词表大小的数据)

      2)  使用的是词袋模型,没有词序信息


       1) 层次softmax

             和word2vec中的层次softmax一样,可以将输出层的计算复杂度由原来的 $O(H \cdot V)$ 降低到 $O(H \log_2 V)$ (V表示词表大小,在文本分类任务中对应类别数)

        2) 添加使用n-gram特征








五、Discussion and conclusion
















  1. # ****** 数据预处理 *****
  2. # 主要包括几个部分-数据集加载、读取标签和数据、创建word2id、将数据转化为id, 本次实验还是使用AG数据集合,数据集下载位置 AG News: https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz
  3. # encoding = 'utf-8'
  4. from torch.utils import data
  5. import os
  6. import csv
  7. import nltk
  8. import numpy as np
  9. # 数据集加载
  10. f = open("./data/AG/train.csv")
  11. rows = csv.reader(f,delimiter=',',quotechar='"')
  12. rows = list(rows)
  13. rows[1:5]
  14. [['3',
  15. 'Carlyle Looks Toward Commercial Aerospace (Reuters)',
  16. 'Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.'],
  17. ['3',
  18. "Oil and Economy Cloud Stocks' Outlook (Reuters)",
  19. 'Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.'],
  20. ['3',
  21. 'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)',
  22. 'Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.'],
  23. ['3',
  24. 'Oil prices soar to all-time record, posing new menace to US economy (AFP)',
  25. 'AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.']]
  26. # 读取标签和数据
  27. n_gram,lowercase,label,datas = 2,True,[],[]
  28. for row in rows:
  29. label.append(int(row[0])-1)
  30. txt = " ".join(row[1:])
  31. if lowercase:
  32. txt = txt.lower()
  33. txt = nltk.word_tokenize(txt) #将句子转化为词
  34. new_txt = []
  35. for i in range(len(txt)):
  36. for j in range(n_gram):
  37. if j<=i:
  38. new_txt.append(" ".join(txt[i-j:i+1]))
  39. datas.append(new_txt)
  40. # word2id
  41. min_count,word_freq = 3,{}
  42. for data in datas:
  43. for word in data:
  44. if word not in word_freq:
  45. word_freq[word] = 1
  46. else:
  47. word_freq[word] += 1
  48. # 首先构建uni-gram,不需要hash
  49. word2id = {"<pad>":0,"<unk>":1}
  50. for word in word_freq:
  51. if word_freq[word] < min_count or " " in word:
  52. continue
  53. word2id[word] = len(word2id)
  54. uniwords_num = len(word2id)
  55. # 构建2-gram以上的词,需要hash
  56. for word in word_freq:
  57. if word_freq[word] < min_count or " " not in word:
  58. continue
  59. word2id[word] = len(word2id)
  60. # 将文本中的词都转化为id,设置句子长度为100,词表最大限制为1w
  61. max_length = 100
  62. for i,data in enumerate(datas):
  63. for j,word in enumerate(data):
  64. if " " not in word:
  65. datas[i][j] = word2id.get(word,1)
  66. else:
  67. datas[i][j] = word2id.get(word,1)%10000 + uniwords_num
  68. datas[i] = datas[i][0:max_length] + [0]*(max_length - len(datas[i]))


  1. # """ 模型代码 """
  2. # encoding='utf-8'
  3. import torch
  4. import torch.nn as nn
  5. import numpy as np
  6. class Fasttext(nn.Module):
  7. def __init__(self,vocab_size,embedding_size,max_length,label_num):
  8. super(Fasttext,self).__init__()
  9. self.embedding = nn.Embedding(vocab_size,embedding_size)
  10. self.avg_pool = nn.AvgPool1d(kernel_size=max_length,stride=1)
  11. self.fc = nn.Linear(embedding_size,label_num)
  12. def forward(self,x):
  13. x = x.long()
  14. out = self.embedding(x) # batch_size * length * embedding_size
  15. out = out.transpose(1,2).contiguous() # batch_size * embedding_size * length
  16. out = self.avg_pool(out).squeeze() # batch_size * embedding_size
  17. out = self.fc(out) # batch_size * label_num
  18. return out
  19. fasttext = Fasttext(vocab_size=1000,embedding_size=10,max_length=100,label_num=4)
  20. test = torch.zeros([64,100]).long()
  21. out = fasttext(test)
  22. """ 查看网络参数 """
  23. from torchsummary import summary
  24. summary(fasttext,input_size=(100,))
  1. """ 模型训练 """
  2. # encoding = 'utf-8'
  3. import torch
  4. import torch.autograd as autograd
  5. import torch.nn as nn
  6. import torch.optim as optim
  7. from model import Fasttext
  8. from data import AG_Data
  9. import numpy as np
  10. from tqdm import tqdm
  11. import config as argumentparser
  12. config = argumentparser.ArgumentParser()
  13. """ 加载数据集 """
  14. training_set = AG_Data("/AG/train.csv",
  15. min_count = config.min_count,
  16. max_length=config.max_length,
  17. n_gram=config.n_gram)
  18. train_iter = torch.utils.data.DataLoader(dataset=training_set,
  19. batch_size=config.batch_size,
  20. shuffle=True,
  21. num_workers=0)
  22. test_set = AG_Data(data_path="/AG/test.csv",
  23. min_count=config.min_count,
  24. max_length = config.max_length,
  25. n_gram = config.n_gram,
  26. word2id = training_set.word2id,
  27. uniwords_num=training_set.uniwords_num)
  28. test_iter = torch.utils.data.DataLoader(dataset=test_set,
  29. batch_size=config.batch_size,
  30. shuffle = True,
  31. num_workers=0)
  32. """ 构建模型 """
  33. model = Fasttext(vocab_size=training_set.uniwords_num+100000,
  34. embedding_size=config.embed_size,
  35. max_length=config.max_length,
  36. label_num=config.label_num)
  37. if config.cuda and torch.cuda.is_available():
  38. model.cuda()
  39. criterion = nn.CrossEntropyLoss()
  40. optimizer = optim.Adam(model.parameters(),lr=config.learning_rate)
  41. loss = -1
  42. def get_test_result(data_iter,data_set):
  43. mode.eval()
  44. true_sample_num = 0
  45. for data,label in data_iter:
  46. if config.cuda and torch.cuda.is_available():
  47. data = data.cuda()
  48. label = label.cuda()
  49. else:
  50. data = torch.autograd.Variable(data).long()
  51. out = model(data)
  52. true_sample_num += np.sum((torch.argmax(out,1)==label.long()).cpu().numpy())
  53. acc = true_sample_num/data_set.__len__()
  54. return acc
  55. for epoch in range(1):
  56. model.train()
  57. process_bar = tqdm(train_iter)
  58. for data,label in process_bar:
  59. if config.cuda and torch.cuda.is_available():
  60. data = data.cuda()
  61. label = label.cuda()
  62. else:
  63. data = torch.autograd.Variable(data).long()
  64. label = torch.autograd.Variable(label).squeeze()
  65. out = model(data)
  66. loss_now = criterion(out,autograd.Variable(label.long()))
  67. if loss == -1:
  68. loss = loss_now.data.item()
  69. else:
  70. loss = 0.95*loss + 0.05*loss_now.data.item()
  71. process_bar.set_postfix(loss=loss_now.data.item())
  72. process_bar.update()
  73. optimizer.zero_grad()
  74. loss_now.backward()
  75. optimizer.step()
  76. test_acc = get_test_result(test_iter,test_set)
  77. print("The test acc is: %.5f" % test_acc)


