
Transformers in Practice: Pretraining Language Models


!pip install transformers datasets evaluate accelerate 

I. Masked Language Model


1. Import the required packages

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer

2. Load the dataset

ds = Dataset.load_from_disk("./wiki_cn_filtered/")

# Equivalent, but requires importing the function directly:
# from datasets import load_from_disk
# ds = load_from_disk("./wiki_cn_filtered/")
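If you don't have the on-disk copy, a corpus of the same kind can be pulled from the Hugging Face Hub instead. The repo id below is an assumption on my part (a filtered zh-Wikipedia dump matching the 'wikipedia.zh2307' source tag visible in the sample), not something given in the original walkthrough:

# Hypothetical Hub id -- adjust to wherever your copy of the filtered
# zh-Wikipedia dump lives; only the on-disk path above is from the original.
ds = load_dataset("pleisto/wikipedia-cn-20230720-filtered", split="train")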
ds[0]
'''
{'source': 'wikipedia.zh2307',
 'completion': "西安交通大学博物馆(Xi'an Jiaotong University Museum)是一座位于西安交通大学的博物馆,馆长是锺明善。\n历史\n2004年9月20日开始筹建,2013年4月8日正式建成开馆,位于西安交通大学兴庆校区陕西省西安市咸宁西路28号。建筑面积6,800平米,展厅面积4,500平米,馆藏文物4,900余件。包括历代艺术文物馆、碑石书法馆、西部农民画馆、邢良坤陶瓷艺术馆、陕西秦腔博物馆和书画展厅共五馆一厅。\n营业时间\n* 周一至周六:上午九点至十二点,下午一点至五点\n* 周日闭馆"}
'''

3. Preprocess the dataset

  • Note: it makes no difference whether you pass padding=True here; DataCollatorForLanguageModeling pads each batch dynamically at collation time anyway.
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")

def process_func(examples):
    # Truncate each article to at most 384 tokens; padding is left to the collator
    return tokenizer(examples["completion"], max_length=384, truncation=True)
tokenized_ds = ds.map(process_func, batched=True, remove_columns=ds.column_names)
tokenized_ds
'''
Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 10000
})
'''
  • Both the masking and the label construction are handled by DataCollatorForLanguageModeling (a sketch of how it works follows the tokenizer check below).
from torch.utils.data import DataLoader

# mlm: whether to run the masked-LM objective (defaults to True)
# mlm_probability: probability of selecting each token for masking
dl = DataLoader(tokenized_ds, batch_size=2, 
                collate_fn=DataCollatorForLanguageModeling(tokenizer, 
                                                           mlm=True, 
                                                           mlm_probability=0.15))
next(enumerate(dl))
'''
(0,
 {'input_ids': tensor([[  101,  6205,   103,   769,  6858,  1920,   103,  1300,  4289,   103,
           8020, 13135,   112,  9064, 12095,  8731,   103,  8181,  8736, 10553,
           8021,  3221,   671,  2429,   855,   754,  6205,  2128,   769,  6858,
            103,  2110,  4638,  1300,  4289,  7667,  8024,  7667,  7270,  3221,
           7247,  3209,  1587,   511,  1325,  1380,  8258,   103,   130,  3299,
           8113,  3189,  2458,  1993,  5040,  2456,  8024,  8138,  2399,   125,
           3299,   129,  3189,  3633,   103,  2456,  2768,  2458,  7667,  8024,
            855,   754,   103,  2128,   103,   103,  1920,  2110,  1069,  2412,
           3413,  1277,  7362,  6205,  4689,  6205,  2128,  2356,   103,   103,
           6205,  6662,  8143,  1384,   511,  2456, 19679,  7481,   103,   127,
            117,  8280,  2398,  5101,   103,  2245,  1324,   103,  4916,   125,
            117,  8195,  2398,  5101,  8024,  7667,  5966,  3152,  4289,   125,
            117,  8567,   865,   816,   511,  1259,  2886,  1325,   807,  5686,
            103,  3152,  4289,  7667,   510,  4811,  4767,   741, 16272,  7667,
            510,  6205,  6956,   103,  3696,  4514,  7667,   510,   103,  5679,
           1787,  7378,  4487,  5686,   103,  7667,   510,  7362,  6205,  4912,
           5579,  1300,  4289,  7667,  1469,   741,  4514,  2245,   103,   103,
            103,  7667,   671,  6908,   103,  5852,   689,  3198, 14303,   115,
            103,   671,  5635,  1453,  1063,  8038,   677,  1286,   736,  4157,
           5635,  1282,   753,  4157,   103,   678,   103,   671,  4157,   103,
            758,  4157,   115,  1453,   103,  7308,  7667,   102,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0],
         [  101,   103,  1921,   712,  3136,   833,  8024,   712,  3136,  1730,
           8020, 10676,   103,  8588,  9401, 11075, 11485, 10967,   103,  8321,
           8024,  4684,  6406,   711,   712,  3136,   833,  6379,  8021,  3221,
           4294,  2137,  1765,  1277,  4638,   712,  3136,   103,  2792,  5299,
           2768,  4638,  6379,   752,  3322,  3354,  8024,  6858,  2382,   868,
            711,  6821,   763,   702,  1166,  3136,   833,  8027,  1765,  3175,
           3136,   833,  4638,  5632,  3780,  3322,  3354,   511,  7478,  2382,
           2578,  2595,  4638,   712,  3136, 17078,  6379,  2100,  1762,   103,
            719,  8024,   852,   712,  3136,  1730,   868,   711,   671,  4905,
           3322,  3354,   103,  1798,  3221,   794,  3464,  5881,  1082,  5018,
            753,  2237,  1920,  1062,   103,  6379,  7562,  2357,   749,   515,
            712,  3136,  1762,  3136,   833,  1079,  4288,  4130,  5466,  1218,
            103,  3791,   103,   722,  1400,  2458,  1993,  4638,  8024,  2400,
           1762,   103,  2134,   924,  4882,  1063,   686,   754,  9093,  2399,
           5041,  5392,   517,  1760,  3136,   833,   518,  8020, 12470,   103,
          12557,   103,  9542,  8722, 10361,  8021,  5632,  1220,  6405,   741,
           1400,  2141,   103,   511,  4680,  1184,   103,  3136,  1730,  4638,
           6817,   868,   103,  3326,  1213,  1469,  6569,   103,  8024,  6963,
           6226,  5745,  1762,  8715,  2399,  4276,   517,  1921,   712,  3136,
           3791,  1073,   518,  1079,  8020,  5018,  8252,  8161,   100,  8208,
           8160,  3340,  8021,   511,   103,  3136,  1730,  4638,  3891,  4667,
           5745,  1741,  6858,  2382,  3221,   898,   103,  1765,   103,  6804,
           4518,  3341,  2137,   721,  8024,   103,  1914,   809,  1744,  2157,
            711,   103,   103,  8024,  3300,  3198,   103,   833,  4507,  3144,
            702,   103,  2157, 17435,  2768,   671,   702,   712,  3136,   103,
            511,  4294,  2137,  4638,  2339,   868,  1469,  3326,  1213,   833,
            103,   678,  3123,  5314,   712,  3136,   103,   103,   103,  1166,
           3221,  1068,   103,  2477,  3054,  4638,  4851,   811,  6226,  5745,
            511,   712,   103,  1730,  1762,  3249,  6881,  6226,  2526,  2772,
           4294,  2137,   103,  1218,  4638,  2956,  9834,   678,  5815,  2533,
           3326,  1213,   511,   898,  4212,  3136,   833,  3791,  2137,   721,
           8024,  1762,  3378,   763,  2658,  1105,   678,  8024,   103,  3136,
           1730,  4638,  1104,  2137,  7444,  1358,   103,  1760,  2429,   103,
           2821,  1114,   511,  1392,   103, 16090,  3136,   679,  7444,   103,
           2461,   800,   812,  1762,   712,  3136,  1730,  1079,  4638,  3326,
           1213,  8024,  5445,   103,  5330,  6566,  6569,  3780,  4415,   103,
           5632,  4638,   172,  1277,   511,  1154,  6134,  3315,  1154,  6134,
            898,  4212,   103,  7650,   517,  2134,  2429,  2399,  7063,   518,
            722,  6381,  6770,   102]]), 
'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ -100,  -100,  2128,  -100,  -100,  -100,  2110,  -100,  -100,  7667,
           -100,  -100,  -100,  -100,  -100,  -100,  8626,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           1920,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  2399,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  2466,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  6205,  -100,   769,  6858,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  1496,  2123,
           -100,  -100,  -100,  -100,  -100,  -100,  5029,  -100,  4916,  -100,
           -100,  -100,  -100,  -100,  8024,  -100,  -100,  7481,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           3318,  -100,  4289,  -100,  -100,  -100,  -100,  -100,  3791,  -100,
           -100,  -100,  -100,  1093,  -100,  -100,  -100,  -100,  6928,  -100,
           -100,  -100,  -100,  -100,  3318,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  1324,  1066,
            758,  -100,  -100,  1324,   511,  -100,  -100,  -100,  7313,  -100,
           1453,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  8024,  -100,  1286,  -100,  -100,  5635,
           -100,  -100,  -100,  -100,  3189,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100],
         [ -100,  1762,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  8331,  -100,  -100,  -100,  -100,  -100,  8936,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  1079,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,   833,  -100,  -100,  -100,  2347,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  5102,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,   833,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
            516,  -100,   808,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  3136,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 11619,
           -100,  8154,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  3177,  -100,  -100,  -100,   712,  -100,  -100,  -100,
           -100,  -100,   510,  -100,  -100,  -100,  -100,   818,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           3791,  -100,  -100,  -100,  8020,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,   712,  -100,  -100,  -100,  -100,  -100,
           -100,  1741,  -100,  -100,  -100,  -100,  4212,  -100,  4415,  -100,
           -100,  -100,  -100,  -100,  -100,  1920,  -100,  -100,  -100,  -100,
           -100,  1296,   855,  -100,  -100,  -100,   738,  -100,  -100,  -100,
           -100,  1744,  -100,  5299,  -100,  -100,  -100,  -100,  -100,  1730,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           6158,  -100,  -100,  -100,  -100,  -100,  1730,  8024,  4294,  -100,
           3221,  -100,   754,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  3136,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,   818,  -100,  -100,  -100,  3326,  -100,  -100,  2533,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,   678,  -100,   712,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  1168,  -100,  -100,  4638,
           -100,  -100,  -100,  -100,   702,   712,  -100,  -100,  -100,  3123,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  5326,  -100,  -100,  -100,  -100,  -100,  1392,
           -100,  -100,  3136,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  8166,  2399,  -100,  -100,  -100,  -100,  -100,   518,
           -100,  -100,  -100,  -100]])})
'''
tokenizer.mask_token, tokenizer.mask_token_id
'''
('[MASK]', 103)
'''
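The collator's masking strategy follows the standard BERT recipe. Below is a minimal sketch of the idea, simplified from (not identical to) the library's implementation -- for instance, it ignores the special-tokens mask:

import torch

def mlm_mask(input_ids, tokenizer, mlm_probability=0.15):
    # Copy inputs to labels; loss will only be computed at masked positions
    labels = input_ids.clone()
    masked = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked] = -100
    # 80% of masked positions become [MASK] (id 103 for this tokenizer)
    replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
    input_ids[replaced] = tokenizer.mask_token_id
    # 10% become a random token; the remaining 10% keep the original token,
    # which the model must still predict
    randomized = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~replaced
    random_ids = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    input_ids[randomized] = random_ids[randomized]
    return input_ids, labels

This matches the batch above: every position that is -100 in labels is ignored by the loss, and only the roughly 15% of sampled positions carry the original token id the model must recover.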

4. Create the model

  • Use AutoModelForMaskedLM, which adds the masked-LM head on top of the encoder.
model = AutoModelForMaskedLM.from_pretrained("hfl/chinese-macbert-base")

5. Configure the training arguments

args = TrainingArguments(
    output_dir="./masked_lm",
    per_device_train_batch_size=32,
    logging_steps=10,
    num_train_epochs=1
)

6. Create the trainer

trainer = Trainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForLanguageModeling(tokenizer, 
                                                  mlm=True, 
                                                  mlm_probability=0.15)
)

7. Train the model

trainer.train()

8. Inference

from transformers import pipeline

pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=0)
pipe("西安交通[MASK][MASK]博物馆(Xi'an Jiaotong University Museum)是一座位于西安交通大学的博物馆")
'''
[[{'score': 0.9948391318321228,
   'token': 1920,
   'token_str': '大',
   'sequence': "[CLS] 西 安 交 通 大 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0021944143809378147,
   'token': 2110,
   'token_str': '学',
   'sequence': "[CLS] 西 安 交 通 学 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0005357894115149975,
   'token': 2339,
   'token_str': '工',
   'sequence': "[CLS] 西 安 交 通 工 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0003660529328044504,
   'token': 7770,
   'token_str': '高',
   'sequence': "[CLS] 西 安 交 通 高 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0003039983275812119,
   'token': 4906,
   'token_str': '科',
   'sequence': "[CLS] 西 安 交 通 科 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"}],
 [{'score': 0.9971689581871033,
   'token': 2110,
   'token_str': '学',
   'sequence': "[CLS] 西 安 交 通 [MASK] 学 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0011966773308813572,
   'token': 7368,
   'token_str': '院',
   'sequence': "[CLS] 西 安 交 通 [MASK] 院 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0008860519737936556,
   'token': 1920,
   'token_str': '大',
   'sequence': "[CLS] 西 安 交 通 [MASK] 大 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0001224309962708503,
   'token': 3318,
   'token_str': '术',
   'sequence': "[CLS] 西 安 交 通 [MASK] 术 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 8.154550596373156e-05,
   'token': 4289,
   'token_str': '物',
   'sequence': "[CLS] 西 安 交 通 [MASK] 物 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"}]]
'''
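The pipeline hides a few steps; the same prediction can be done by hand. A quick sketch (it reuses the example sentence and only prints the top-1 token for each mask):

import torch

text = "西安交通[MASK][MASK]博物馆是一座位于西安交通大学的博物馆"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
# Locate the [MASK] positions and decode the highest-scoring token for each
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
for pos in mask_pos:
    print(tokenizer.decode([logits[0, pos].argmax(-1).item()]))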

II. Causal Language Model


1. Import the required packages

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer, BloomForCausalLM

2. Load the dataset

ds = Dataset.load_from_disk("./wiki_cn_filtered/")
ds
'''
Dataset({
    features: ['source', 'completion'],
    num_rows: 10000
})
'''
ds[0]
'''
{'source': 'wikipedia.zh2307',
 'completion': "西安交通大学博物馆(Xi'an Jiaotong University Museum)是一座位于西安交通大学的博物馆,馆长是锺明善。\n历史\n2004年9月20日开始筹建,2013年4月8日正式建成开馆,位于西安交通大学兴庆校区陕西省西安市咸宁西路28号。建筑面积6,800平米,展厅面积4,500平米,馆藏文物4,900余件。包括历代艺术文物馆、碑石书法馆、西部农民画馆、邢良坤陶瓷艺术馆、陕西秦腔博物馆和书画展厅共五馆一厅。\n营业时间\n* 周一至周六:上午九点至十二点,下午一点至五点\n* 周日闭馆"}
'''

3. Preprocess the dataset

An end-of-sequence token (eos) must be appended so the model learns when to stop generating.
No padding is needed in the tokenizer call: when DataCollatorForLanguageModeling is used for causal language modeling it pads the batch automatically, and with this tokenizer the padding lands on the left; adding padding=True would still result in left padding (a quick check follows the code below).

tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-389m-zh")

def process_func(examples):
    # Append eos to every article so the model learns an explicit stopping point
    contents = [e + tokenizer.eos_token for e in examples["completion"]]
    return tokenizer(contents, max_length=384, truncation=True)
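
A quick check of the two points above (the expected values follow from the batch dump and the tokenizer outputs shown later in this step):

# DataCollatorForLanguageModeling pads via tokenizer.pad(), which honors the
# tokenizer's own padding_side setting -- hence the left padding in the batch below
print(tokenizer.padding_side)   # 'left' for this Bloom tokenizer
# And process_func appends the eos token (id 2) to every article
print(tokenizer(ds[0]["completion"] + tokenizer.eos_token)["input_ids"][-1])  # 2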
tokenized_ds = ds.map(process_func, batched=True, remove_columns=ds.column_names)
tokenized_ds
'''
Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 10000
})
'''
from torch.utils.data import DataLoader

dl = DataLoader(tokenized_ds, batch_size=2, collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False))

next(enumerate(dl))
'''
(0,
 {'input_ids': tensor([[    3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3, 13110, 34800,
          13535,   916, 33156,    10,   256,   576,   387,   479,   681,  5453,
          10955,   915, 24124,  5317, 13110,  6573, 20757, 13535,   355,  5358,
           1490,   583, 28056,  1407,  3855,   671,  6113,   189,  6732,  4302,
           9488,  3434,  6900,  1322,   355, 37336,  9825,  4608, 13461,  1359,
           5358,   355,  5317, 13110, 34800,  4433,  7189, 25722, 29747, 13110,
           1498, 12047,  6347, 23563,  2139,  2066,   420, 29288,    25,    15,
           7635, 39288,   355,  1484,  5835,  6272,    23,    15,  4180, 39288,
            355,  5358,  4516, 11621,    23,    15, 10641,  4887,  1712,   420,
           2450, 31163,  8085, 11621,  5358,   553,  9888,  2731, 21335,  5358,
            553,  9876, 14011,  4434,  5358,   553, 21484,  4514, 17170, 25871,
           8085,  5358,   553, 17489,  6945, 11097, 13535,   641, 33623,  1484,
           5835,  1689,  2063,  5358,   569,  5835,   671, 16225,  3422,   189,
             13, 23158, 27894, 33227,  1022, 11396,  3347,  1813,  1504,  6566,
           1813,   355,  9155,  8633,  1504,  2063,  1813,   189,    13, 23158,
            813,  7817,  5358,     2],
         [  586, 11835,   739,   355,  8417,  2300,   916, 16892,  5077,   580,
          15549,  1434,   996,   307,   387,   355,  1997,  8236,   775,  8417,
           2152,  6496, 12176,  3728,  7692,  1503,  1528,  1014, 17887, 18001,
           3250,   355,  6250,  3896,  2670, 20882, 16080, 14005,  3993,  1503,
          12295,  8930,  3250,   420,  4208,  4326,  5367,  8417,  2152,  4632,
          37591,   355,  1379,  8417,  2300, 32824,  3250,  9015, 18945, 16714,
           5908, 13551, 30330, 23756,  2152, 15976,   657,  5923,  8417,   586,
          16080,  1528,  8941,  5917, 14035,  5895, 14092,  4353, 29445,   355,
           8790, 21595,  1450, 15911, 31116, 10345, 29940, 10874,  1125,  4829,
          16080,  7093, 22939,   737,   262,   387,    72,   272,   831,   455,
             72,   915,  8860, 20725,  1934,  1084,  5478,   420,  4415,  8417,
          26263, 12726,   553, 10875, 34820,   355,  1266,  5498,   586, 14907,
          32795, 11835,   904, 19934,  1528, 19531, 20517,  1349, 19472, 28879,
            671,  8417, 26263, 17020,  5963, 22388, 11900, 12669, 13240,  1042,
           9783,   355,  7242,   714,  1806,   775,  6500,   355, 11526, 10185,
           1293,  1665, 15984,  7092,  1617,  8417,  2300,   420, 19972, 25622,
          10875, 17500, 26523,  2391,  8417,  2300,   355,  6751,  2836, 13539,
           8247,   373, 30201,  5498,   420,  8417,  2300,   586,  8523, 19358,
           1298, 12176, 30939, 10739,   964,  4318, 10875,   420, 11900, 16080,
            904,  9783,   355, 22464,  9658,   355,  8417,  2300, 13561,  2054,
           4983,  4829, 30800,  7262,   420, 11394,  8417, 35143, 11937, 15682,
           8417,  2300,  7283, 10875,   355,  1016,  4179,  5039, 14027, 26215,
          26835,   671, 15095,   189,  1165, 15095, 11900,  6184,  1125,  3244,
           3687,   622,  8785,  1121,   891, 13765,   671, 10199,   189,    13,
            210,  6940,  3728,  9552,  6082,  8417,  2300,   916,  4375,   714,
           3679,  1806, 10567,   915,   189,    13,   210, 16131, 11835,  8417,
           2300,   189,    13, 24075,  3728,  8417,  2300,   916,  4829,  3687,
           9000, 27689,  8417,  2300,  1300, 11243,  2062, 28431, 27689, 11835,
           8417,  2300, 13224,  4829,  3687, 15964,   915,   189,    13, 27340,
          11835,  8417,  2300,   189,    13, 27340,  3982,  3728,  8417,  2300,
            916,  4375,  2450,  3272,  1234, 19083,   553,  3512,  3121,  1728,
            641,  3092,  2113,  7843,   915,   189,    13,   210, 15402,  8417,
           2300,   189,    13, 10057,   108, 12693, 14624, 29379,  4719,  6533,
            739,   916, 16148, 10981, 21350,  9067,  1203,  8931,  1258, 11835,
           4719,  6533,   739, 20393,   189,    13,   210, 19546,  1517, 11835,
           8417,  2300,   189,    13,   210, 23928,   168,   117,   245,  6279,
            114,   240,   170,   100,   124,   168,   117,   228,  6279,   100,
            124,   168,   117,   228,   171,   238,   224, 41356,   236, 24175,
          11082, 10981, 21350,  9067]]), 
'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 13110, 34800,
          13535,   916, 33156,    10,   256,   576,   387,   479,   681,  5453,
          10955,   915, 24124,  5317, 13110,  6573, 20757, 13535,   355,  5358,
           1490,   583, 28056,  1407,  3855,   671,  6113,   189,  6732,  4302,
           9488,  3434,  6900,  1322,   355, 37336,  9825,  4608, 13461,  1359,
           5358,   355,  5317, 13110, 34800,  4433,  7189, 25722, 29747, 13110,
           1498, 12047,  6347, 23563,  2139,  2066,   420, 29288,    25,    15,
           7635, 39288,   355,  1484,  5835,  6272,    23,    15,  4180, 39288,
            355,  5358,  4516, 11621,    23,    15, 10641,  4887,  1712,   420,
           2450, 31163,  8085, 11621,  5358,   553,  9888,  2731, 21335,  5358,
            553,  9876, 14011,  4434,  5358,   553, 21484,  4514, 17170, 25871,
           8085,  5358,   553, 17489,  6945, 11097, 13535,   641, 33623,  1484,
           5835,  1689,  2063,  5358,   569,  5835,   671, 16225,  3422,   189,
             13, 23158, 27894, 33227,  1022, 11396,  3347,  1813,  1504,  6566,
           1813,   355,  9155,  8633,  1504,  2063,  1813,   189,    13, 23158,
            813,  7817,  5358,     2],
         [  586, 11835,   739,   355,  8417,  2300,   916, 16892,  5077,   580,
          15549,  1434,   996,   307,   387,   355,  1997,  8236,   775,  8417,
           2152,  6496, 12176,  3728,  7692,  1503,  1528,  1014, 17887, 18001,
           3250,   355,  6250,  3896,  2670, 20882, 16080, 14005,  3993,  1503,
          12295,  8930,  3250,   420,  4208,  4326,  5367,  8417,  2152,  4632,
          37591,   355,  1379,  8417,  2300, 32824,  3250,  9015, 18945, 16714,
           5908, 13551, 30330, 23756,  2152, 15976,   657,  5923,  8417,   586,
          16080,  1528,  8941,  5917, 14035,  5895, 14092,  4353, 29445,   355,
           8790, 21595,  1450, 15911, 31116, 10345, 29940, 10874,  1125,  4829,
          16080,  7093, 22939,   737,   262,   387,    72,   272,   831,   455,
             72,   915,  8860, 20725,  1934,  1084,  5478,   420,  4415,  8417,
          26263, 12726,   553, 10875, 34820,   355,  1266,  5498,   586, 14907,
          32795, 11835,   904, 19934,  1528, 19531, 20517,  1349, 19472, 28879,
            671,  8417, 26263, 17020,  5963, 22388, 11900, 12669, 13240,  1042,
           9783,   355,  7242,   714,  1806,   775,  6500,   355, 11526, 10185,
           1293,  1665, 15984,  7092,  1617,  8417,  2300,   420, 19972, 25622,
          10875, 17500, 26523,  2391,  8417,  2300,   355,  6751,  2836, 13539,
           8247,   373, 30201,  5498,   420,  8417,  2300,   586,  8523, 19358,
           1298, 12176, 30939, 10739,   964,  4318, 10875,   420, 11900, 16080,
            904,  9783,   355, 22464,  9658,   355,  8417,  2300, 13561,  2054,
           4983,  4829, 30800,  7262,   420, 11394,  8417, 35143, 11937, 15682,
           8417,  2300,  7283, 10875,   355,  1016,  4179,  5039, 14027, 26215,
          26835,   671, 15095,   189,  1165, 15095, 11900,  6184,  1125,  3244,
           3687,   622,  8785,  1121,   891, 13765,   671, 10199,   189,    13,
            210,  6940,  3728,  9552,  6082,  8417,  2300,   916,  4375,   714,
           3679,  1806, 10567,   915,   189,    13,   210, 16131, 11835,  8417,
           2300,   189,    13, 24075,  3728,  8417,  2300,   916,  4829,  3687,
           9000, 27689,  8417,  2300,  1300, 11243,  2062, 28431, 27689, 11835,
           8417,  2300, 13224,  4829,  3687, 15964,   915,   189,    13, 27340,
          11835,  8417,  2300,   189,    13, 27340,  3982,  3728,  8417,  2300,
            916,  4375,  2450,  3272,  1234, 19083,   553,  3512,  3121,  1728,
            641,  3092,  2113,  7843,   915,   189,    13,   210, 15402,  8417,
           2300,   189,    13, 10057,   108, 12693, 14624, 29379,  4719,  6533,
            739,   916, 16148, 10981, 21350,  9067,  1203,  8931,  1258, 11835,
           4719,  6533,   739, 20393,   189,    13,   210, 19546,  1517, 11835,
           8417,  2300,   189,    13,   210, 23928,   168,   117,   245,  6279,
            114,   240,   170,   100,   124,   168,   117,   228,  6279,   100,
            124,   168,   117,   228,   171,   238,   224, 41356,   236, 24175,
          11082, 10981, 21350,  9067]])})
'''
tokenizer.pad_token, tokenizer.pad_token_id
'''
('<pad>', 3)
'''
tokenizer.eos_token, tokenizer.eos_token_id
'''
('</s>', 2)
'''
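
With mlm=False the collator's job is much simpler than in the MLM case. A minimal sketch of the idea (not the library's exact code):

import torch

def clm_labels(input_ids, pad_token_id):
    # Labels are just a copy of input_ids with padding masked out as -100.
    # The one-position shift between inputs and targets happens inside the
    # model's forward pass, not in the collator.
    labels = input_ids.clone()
    labels[labels == pad_token_id] = -100
    return labels

That is exactly what the labels tensor in the batch above shows: -100 at every <pad> position (id 3) and the raw token ids everywhere else.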

4. Create the model

model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-389m-zh")

5. Configure the training arguments

args = TrainingArguments(
    output_dir="./causal_lm",
    per_device_train_batch_size=4, # together with gradient_accumulation_steps=8, the effective batch size is 4 * 8 = 32; same result as a batch of 32, just slower
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)

6. Create the Trainer

trainer = Trainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)

7. Train the model

trainer.train()
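
A common sanity check for a causal LM is perplexity, i.e. the exponential of the eval loss. A minimal sketch, assuming you hold out a slice of tokenized_ds as an eval split (the split itself is not part of the original walkthrough):

import math

# Hypothetical hold-out split for evaluation
split = tokenized_ds.train_test_split(test_size=0.05, seed=42)
metrics = trainer.evaluate(eval_dataset=split["test"])
print("perplexity:", math.exp(metrics["eval_loss"]))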

8. Inference

(See also: a companion summary of decoding and generation strategies for LLMs.)

from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
pipe("西安交通大学博物馆(Xi'an Jiaotong University Museum)是一座位于西安", max_length=128, do_sample=True)