
Hands-on: training Transformers on your own dataset (distilbert-base-uncased)
  1. """
  2. 使用IMDb评论进行序列分类
  3. """
  4. #先下载数据
  5. # wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
  6. # tar -xf aclImdb_v1.tar.gz
  7. #整理文件
  8. from pathlib import Path
  9. def read_imdb_split(split_dir):
  10. split_dir = Path(split_dir)
  11. texts = []
  12. labels = []
  13. for label_dir in ["pos", "neg"]:
  14. for text_file in (split_dir/label_dir).iterdir():
  15. texts.append(text_file.read_text())
  16. labels.append(0 if label_dir is "neg" else 1)
  17. return texts, labels
  18. train_texts, train_labels = read_imdb_split('aclImdb/train')
  19. test_texts, test_labels = read_imdb_split('aclImdb/test')
  20. #划分训练集和测试集
  21. from sklearn.model_selection import train_test_split
  22. train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)
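# Optional sanity check (added here as an illustration, not from the original
# post): the IMDb train and test splits each hold 25,000 reviews, so after the
# 80/20 split we expect 20,000 / 5,000 / 25,000.
print(len(train_texts), len(val_texts), len(test_texts))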
# We've now read in the dataset. Next, tokenization. We'll eventually fine-tune
# a pretrained DistilBERT classifier, so we use the matching DistilBERT tokenizer.
from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)
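# A quick look at what the tokenizer returns (illustrative, not in the original
# post): each encoding is a dict-like object with 'input_ids' and
# 'attention_mask', one padded/truncated sequence per review.
print(train_encodings.keys())
print(len(train_encodings['input_ids'][0]))  # length of the first (padded) sequence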
import torch

class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # Convert the idx-th example into a dict of tensors, as the Trainer expects
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = IMDbDataset(train_encodings, train_labels)
val_dataset = IMDbDataset(val_encodings, val_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)
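# Illustrative check (added for this write-up): fetch one example to verify the
# dataset returns tensors in the form the Trainer expects.
sample = train_dataset[0]
print({k: v.shape for k, v in sample.items()})  # input_ids/attention_mask: [seq_len], labels: scalar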
# Fine-tune with the Trainer API
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for the learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)

# num_labels defaults to 2, which matches binary pos/neg classification
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

trainer = Trainer(
    model=model,                  # the instantiated Transformers model to be trained
    args=training_args,           # training arguments, defined above
    train_dataset=train_dataset,  # training dataset
    eval_dataset=val_dataset,     # evaluation dataset
)

trainer.train()
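# After training, we can score the held-out test set. The snippet below is a
# minimal sketch (not part of the original post): it re-creates the Trainer
# with a compute_metrics callback so that evaluate() reports accuracy in
# addition to the loss.
import numpy as np

def compute_metrics(eval_pred):
    preds = np.argmax(eval_pred.predictions, axis=-1)
    return {"accuracy": (preds == eval_pred.label_ids).mean()}

eval_trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
)
print(eval_trainer.evaluate(test_dataset))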