赞
踩
该数据集包含2,507篇研究论文标题,并已手动分类为5个类别(即会议)。
- import torch
- from tqdm.notebook import tqdm
-
- from transformers import BertTokenizer
- from torch.utils.data import TensorDataset
-
- from transformers import BertForSequenceClassification
-
- df = pd.read_csv('data/title_conference.csv')
- df.head()
df['Conference'].value_counts()
您可能已经注意到我们的类别不平衡,我们将在稍后解决这个问题。
- possible_labels = df.Conference.unique()
-
- label_dict = {}
- for index, possible_label in enumerate(possible_labels):
- label_dict[possible_label] = index
- label_dict
df['label'] = df.Conference.replace(label_dict)
由于标签不平衡,我们以分层的方式划分数据集,使用这个作为类别标签。
在划分后,我们的标签分布将如下所示。
- from sklearn.model_selection import train_test_split
-
- X_train, X_val, y_train, y_val = train_test_split(df.index.values,
- df.label.values,
- test_size=0.15,
- random_state=42,
- stratify=df.label.values)
-
- df['data_type'] = ['not_set']*df.shape[0]
-
- df.loc[X_train, 'data_type'] = 'train'
- df.loc[X_val, 'data_type'] = 'val'
-
- df.groupby(['Conference', 'label', 'data_type']).count()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。