当前位置:   article > 正文

NLP 算法实战项目:使用 BERT 进行文本多分类_bert多文本分类

bert多文本分类

数据

该数据集包含2,507篇研究论文标题,并已手动分类为5个类别(即会议)。

探索与预处理

  1. import torch
  2. from tqdm.notebook import tqdm
  3. from transformers import BertTokenizer
  4. from torch.utils.data import TensorDataset
  5. from transformers import BertForSequenceClassification
  6. df = pd.read_csv('data/title_conference.csv')
  7. df.head()

图片

df['Conference'].value_counts()

图片

您可能已经注意到我们的类别不平衡,我们将在稍后解决这个问题。

对标签进行编码

  1. possible_labels = df.Conference.unique()
  2. label_dict = {}
  3. for index, possible_label in enumerate(possible_labels):
  4.     label_dict[possible_label] = index
  5. label_dict

图片

df['label'= df.Conference.replace(label_dict)

训练和验证集划分

由于标签不平衡,我们以分层的方式划分数据集,使用这个作为类别标签。

在划分后,我们的标签分布将如下所示。

  1. from sklearn.model_selection import train_test_split
  2. X_train, X_val, y_train, y_val = train_test_split(df.index.values
  3.                                                   df.label.values
  4.                                                   test_size=0.15
  5.                                                   random_state=42
  6.                                                   stratify=df.label.values)
  7. df['data_type'= ['not_set']*df.shape[0]
  8. df.loc[X_train, 'data_type'= 'train'
  9. df.loc[X_val, 'data_type'= 'val'
  10. df.groupby(['Conference''label''data_type']).count()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/664011
推荐阅读
相关标签
  

闽ICP备14008679号