!pip install transformers datasets evaluate accelerate
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer
ds = Dataset.load_from_disk("./wiki_cn_filtered/")
# equivalently (with `from datasets import load_from_disk`):
# ds = load_from_disk("./wiki_cn_filtered/")
ds[0]
'''
{'source': 'wikipedia.zh2307',
'completion': "西安交通大学博物馆(Xi'an Jiaotong University Museum)是一座位于西安交通大学的博物馆,馆长是锺明善。\n历史\n2004年9月20日开始筹建,2013年4月8日正式建成开馆,位于西安交通大学兴庆校区陕西省西安市咸宁西路28号。建筑面积6,800平米,展厅面积4,500平米,馆藏文物4,900余件。包括历代艺术文物馆、碑石书法馆、西部农民画馆、邢良坤陶瓷艺术馆、陕西秦腔博物馆和书画展厅共五馆一厅。\n营业时间\n* 周一至周六:上午九点至十二点,下午一点至五点\n* 周日闭馆"}
'''
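If the local snapshot is missing, the same corpus can presumably be rebuilt from the Hub. The dataset id below is an assumption inferred from the 'wikipedia.zh2307' source field; verify it before relying on it:
# Assumption: the local copy was exported from this Hub dataset
# ds = load_dataset("pleisto/wikipedia-cn-20230720-filtered", split="train")
# ds.save_to_disk("./wiki_cn_filtered/")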
Setting padding=True in the tokenizer is optional here: DataCollatorForLanguageModeling pads each batch dynamically anyway.
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")
def process_func(examples):
    return tokenizer(examples["completion"], max_length=384, truncation=True)
tokenized_ds = ds.map(process_func, batched=True, remove_columns=ds.column_names)
tokenized_ds
'''
Dataset({
features: ['input_ids', 'token_type_ids', 'attention_mask'],
num_rows: 10000
})
'''
DataCollatorForLanguageModeling takes care of all of this (masking, label construction, and padding).
from torch.utils.data import DataLoader
# mlm: whether to perform the MLM task; defaults to True
# mlm_probability: the probability of masking a token for the MLM task
dl = DataLoader(tokenized_ds, batch_size=2,
                collate_fn=DataCollatorForLanguageModeling(tokenizer,
                                                           mlm=True,
                                                           mlm_probability=0.15))
next(enumerate(dl))
'''
(0,
{'input_ids': tensor([[ 101, 6205, 103, 769, 6858, 1920, 103, 1300, 4289, 103,
8020, 13135, 112, 9064, 12095, 8731, 103, 8181, 8736, 10553,
8021, 3221, 671, 2429, 855, 754, 6205, 2128, 769, 6858,
103, 2110, 4638, 1300, 4289, 7667, 8024, 7667, 7270, 3221,
7247, 3209, 1587, 511, 1325, 1380, 8258, 103, 130, 3299,
8113, 3189, 2458, 1993, 5040, 2456, 8024, 8138, 2399, 125,
3299, 129, 3189, 3633, 103, 2456, 2768, 2458, 7667, 8024,
855, 754, 103, 2128, 103, 103, 1920, 2110, 1069, 2412,
3413, 1277, 7362, 6205, 4689, 6205, 2128, 2356, 103, 103,
6205, 6662, 8143, 1384, 511, 2456, 19679, 7481, 103, 127,
117, 8280, 2398, 5101, 103, 2245, 1324, 103, 4916, 125,
117, 8195, 2398, 5101, 8024, 7667, 5966, 3152, 4289, 125,
117, 8567, 865, 816, 511, 1259, 2886, 1325, 807, 5686,
103, 3152, 4289, 7667, 510, 4811, 4767, 741, 16272, 7667,
510, 6205, 6956, 103, 3696, 4514, 7667, 510, 103, 5679,
1787, 7378, 4487, 5686, 103, 7667, 510, 7362, 6205, 4912,
5579, 1300, 4289, 7667, 1469, 741, 4514, 2245, 103, 103,
103, 7667, 671, 6908, 103, 5852, 689, 3198, 14303, 115,
103, 671, 5635, 1453, 1063, 8038, 677, 1286, 736, 4157,
5635, 1282, 753, 4157, 103, 678, 103, 671, 4157, 103,
758, 4157, 115, 1453, 103, 7308, 7667, 102, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 101, 103, 1921, 712, 3136, 833, 8024, 712, 3136, 1730,
8020, 10676, 103, 8588, 9401, 11075, 11485, 10967, 103, 8321,
8024, 4684, 6406, 711, 712, 3136, 833, 6379, 8021, 3221,
4294, 2137, 1765, 1277, 4638, 712, 3136, 103, 2792, 5299,
2768, 4638, 6379, 752, 3322, 3354, 8024, 6858, 2382, 868,
711, 6821, 763, 702, 1166, 3136, 833, 8027, 1765, 3175,
3136, 833, 4638, 5632, 3780, 3322, 3354, 511, 7478, 2382,
2578, 2595, 4638, 712, 3136, 17078, 6379, 2100, 1762, 103,
719, 8024, 852, 712, 3136, 1730, 868, 711, 671, 4905,
3322, 3354, 103, 1798, 3221, 794, 3464, 5881, 1082, 5018,
753, 2237, 1920, 1062, 103, 6379, 7562, 2357, 749, 515,
712, 3136, 1762, 3136, 833, 1079, 4288, 4130, 5466, 1218,
103, 3791, 103, 722, 1400, 2458, 1993, 4638, 8024, 2400,
1762, 103, 2134, 924, 4882, 1063, 686, 754, 9093, 2399,
5041, 5392, 517, 1760, 3136, 833, 518, 8020, 12470, 103,
12557, 103, 9542, 8722, 10361, 8021, 5632, 1220, 6405, 741,
1400, 2141, 103, 511, 4680, 1184, 103, 3136, 1730, 4638,
6817, 868, 103, 3326, 1213, 1469, 6569, 103, 8024, 6963,
6226, 5745, 1762, 8715, 2399, 4276, 517, 1921, 712, 3136,
3791, 1073, 518, 1079, 8020, 5018, 8252, 8161, 100, 8208,
8160, 3340, 8021, 511, 103, 3136, 1730, 4638, 3891, 4667,
5745, 1741, 6858, 2382, 3221, 898, 103, 1765, 103, 6804,
4518, 3341, 2137, 721, 8024, 103, 1914, 809, 1744, 2157,
711, 103, 103, 8024, 3300, 3198, 103, 833, 4507, 3144,
702, 103, 2157, 17435, 2768, 671, 702, 712, 3136, 103,
511, 4294, 2137, 4638, 2339, 868, 1469, 3326, 1213, 833,
103, 678, 3123, 5314, 712, 3136, 103, 103, 103, 1166,
3221, 1068, 103, 2477, 3054, 4638, 4851, 811, 6226, 5745,
511, 712, 103, 1730, 1762, 3249, 6881, 6226, 2526, 2772,
4294, 2137, 103, 1218, 4638, 2956, 9834, 678, 5815, 2533,
3326, 1213, 511, 898, 4212, 3136, 833, 3791, 2137, 721,
8024, 1762, 3378, 763, 2658, 1105, 678, 8024, 103, 3136,
1730, 4638, 1104, 2137, 7444, 1358, 103, 1760, 2429, 103,
2821, 1114, 511, 1392, 103, 16090, 3136, 679, 7444, 103,
2461, 800, 812, 1762, 712, 3136, 1730, 1079, 4638, 3326,
1213, 8024, 5445, 103, 5330, 6566, 6569, 3780, 4415, 103,
5632, 4638, 172, 1277, 511, 1154, 6134, 3315, 1154, 6134,
898, 4212, 103, 7650, 517, 2134, 2429, 2399, 7063, 518,
722, 6381, 6770, 102]]),
'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ -100, -100, 2128, -100, -100, -100, 2110, -100, -100, 7667,
-100, -100, -100, -100, -100, -100, 8626, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
1920, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, 2399, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, 2466, -100, -100, -100, -100, -100,
-100, -100, 6205, -100, 769, 6858, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, 1496, 2123,
-100, -100, -100, -100, -100, -100, 5029, -100, 4916, -100,
-100, -100, -100, -100, 8024, -100, -100, 7481, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
3318, -100, 4289, -100, -100, -100, -100, -100, 3791, -100,
-100, -100, -100, 1093, -100, -100, -100, -100, 6928, -100,
-100, -100, -100, -100, 3318, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, 1324, 1066,
758, -100, -100, 1324, 511, -100, -100, -100, 7313, -100,
1453, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, 8024, -100, 1286, -100, -100, 5635,
-100, -100, -100, -100, 3189, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100],
[ -100, 1762, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, 8331, -100, -100, -100, -100, -100, 8936, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, 1079, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, 833, -100, -100, -100, 2347,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, 5102, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, 833, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
516, -100, 808, -100, -100, -100, -100, -100, -100, -100,
-100, 3136, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, 11619,
-100, 8154, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, 3177, -100, -100, -100, 712, -100, -100, -100,
-100, -100, 510, -100, -100, -100, -100, 818, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
3791, -100, -100, -100, 8020, -100, -100, -100, -100, -100,
-100, -100, -100, -100, 712, -100, -100, -100, -100, -100,
-100, 1741, -100, -100, -100, -100, 4212, -100, 4415, -100,
-100, -100, -100, -100, -100, 1920, -100, -100, -100, -100,
-100, 1296, 855, -100, -100, -100, 738, -100, -100, -100,
-100, 1744, -100, 5299, -100, -100, -100, -100, -100, 1730,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
6158, -100, -100, -100, -100, -100, 1730, 8024, 4294, -100,
3221, -100, 754, -100, -100, -100, -100, -100, -100, -100,
-100, -100, 3136, -100, -100, -100, -100, -100, -100, -100,
-100, -100, 818, -100, -100, -100, 3326, -100, -100, 2533,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, 678, -100, 712, -100,
-100, -100, -100, -100, -100, -100, 1168, -100, -100, 4638,
-100, -100, -100, -100, 702, 712, -100, -100, -100, 3123,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, 5326, -100, -100, -100, -100, -100, 1392,
-100, -100, 3136, -100, -100, -100, -100, -100, -100, -100,
-100, -100, 8166, 2399, -100, -100, -100, -100, -100, 518,
-100, -100, -100, -100]])})
'''
tokenizer.mask_token, tokenizer.mask_token_id
'''
('[MASK]', 103)
'''
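To inspect exactly what the collator masked, decode the positions where labels is not -100. A minimal sketch reusing the DataLoader above:
batch = next(iter(dl))
input_ids, labels = batch["input_ids"][0], batch["labels"][0]
for pos in (labels != -100).nonzero().flatten().tolist():
    # labels keeps the original token; input_ids holds [MASK] (80%),
    # a random token (10%), or the unchanged token (10%)
    print(pos, tokenizer.decode([input_ids[pos].item()]), "->", tokenizer.decode([labels[pos].item()]))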
AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("hfl/chinese-macbert-base")
args = TrainingArguments(
    output_dir="./masked_lm",
    per_device_train_batch_size=32,
    logging_steps=10,
    num_train_epochs=1
)
trainer = Trainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForLanguageModeling(tokenizer,
                                                  mlm=True,
                                                  mlm_probability=0.15)
)
trainer.train()
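A common post-training check is perplexity on held-out data. A sketch, assuming you carve an eval split out of tokenized_ds (for a clean number you would also train only on split["train"]):
import math
split = tokenized_ds.train_test_split(test_size=0.1, seed=42)  # assumed 10% eval split
eval_metrics = trainer.evaluate(eval_dataset=split["test"])
print(f"perplexity: {math.exp(eval_metrics['eval_loss']):.2f}")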
from transformers import pipeline
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=0)
pipe("西安交通[MASK][MASK]博物馆(Xi'an Jiaotong University Museum)是一座位于西安交通大学的博物馆")
'''
[[{'score': 0.9948391318321228,
'token': 1920,
'token_str': '大',
'sequence': "[CLS] 西 安 交 通 大 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
{'score': 0.0021944143809378147,
'token': 2110,
'token_str': '学',
'sequence': "[CLS] 西 安 交 通 学 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
{'score': 0.0005357894115149975,
'token': 2339,
'token_str': '工',
'sequence': "[CLS] 西 安 交 通 工 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
{'score': 0.0003660529328044504,
'token': 7770,
'token_str': '高',
'sequence': "[CLS] 西 安 交 通 高 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
{'score': 0.0003039983275812119,
'token': 4906,
'token_str': '科',
'sequence': "[CLS] 西 安 交 通 科 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"}],
[{'score': 0.9971689581871033,
'token': 2110,
'token_str': '学',
'sequence': "[CLS] 西 安 交 通 [MASK] 学 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
{'score': 0.0011966773308813572,
'token': 7368,
'token_str': '院',
'sequence': "[CLS] 西 安 交 通 [MASK] 院 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
{'score': 0.0008860519737936556,
'token': 1920,
'token_str': '大',
'sequence': "[CLS] 西 安 交 通 [MASK] 大 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
{'score': 0.0001224309962708503,
'token': 3318,
'token_str': '术',
'sequence': "[CLS] 西 安 交 通 [MASK] 术 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
{'score': 8.154550596373156e-05,
'token': 4289,
'token_str': '物',
'sequence': "[CLS] 西 安 交 通 [MASK] 物 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"}]]
'''
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer, BloomForCausalLM
ds = Dataset.load_from_disk("./wiki_cn_filtered/")
ds
'''
Dataset({
features: ['source', 'completion'],
num_rows: 10000
})
'''
ds[0]
'''
{'source': 'wikipedia.zh2307',
'completion': "西安交通大学博物馆(Xi'an Jiaotong University Museum)是一座位于西安交通大学的博物馆,馆长是锺明善。\n历史\n2004年9月20日开始筹建,2013年4月8日正式建成开馆,位于西安交通大学兴庆校区陕西省西安市咸宁西路28号。建筑面积6,800平米,展厅面积4,500平米,馆藏文物4,900余件。包括历代艺术文物馆、碑石书法馆、西部农民画馆、邢良坤陶瓷艺术馆、陕西秦腔博物馆和书画展厅共五馆一厅。\n营业时间\n* 周一至周六:上午九点至十二点,下午一点至五点\n* 周日闭馆"}
'''
We need to append an end-of-sequence (eos) token so the model learns when to stop generating.
No tokenizer-side padding is needed: for causal LM, DataCollatorForLanguageModeling pads automatically, and because this tokenizer pads on the left, even setting padding=True still results in left padding.
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-389m-zh")
def process_func(examples):
    contents = [e + tokenizer.eos_token for e in examples["completion"]]
    return tokenizer(contents, max_length=384, truncation=True)
tokenized_ds = ds.map(process_func, batched=True, remove_columns=ds.column_names)
tokenized_ds
'''
Dataset({
features: ['input_ids', 'attention_mask'],
num_rows: 10000
})
'''
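One caveat: with truncation=True, any document longer than max_length is cut off and loses the eos token we just appended; only the shorter examples keep it. A quick check:
with_eos = sum(1 for ex in tokenized_ds if ex["input_ids"][-1] == tokenizer.eos_token_id)
print(f"{with_eos}/{len(tokenized_ds)} examples still end with eos")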
from torch.utils.data import DataLoader
dl = DataLoader(tokenized_ds, batch_size=2, collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False))
next(enumerate(dl))
'''
(0,
{'input_ids': tensor([[ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 13110, 34800,
13535, 916, 33156, 10, 256, 576, 387, 479, 681, 5453,
10955, 915, 24124, 5317, 13110, 6573, 20757, 13535, 355, 5358,
1490, 583, 28056, 1407, 3855, 671, 6113, 189, 6732, 4302,
9488, 3434, 6900, 1322, 355, 37336, 9825, 4608, 13461, 1359,
5358, 355, 5317, 13110, 34800, 4433, 7189, 25722, 29747, 13110,
1498, 12047, 6347, 23563, 2139, 2066, 420, 29288, 25, 15,
7635, 39288, 355, 1484, 5835, 6272, 23, 15, 4180, 39288,
355, 5358, 4516, 11621, 23, 15, 10641, 4887, 1712, 420,
2450, 31163, 8085, 11621, 5358, 553, 9888, 2731, 21335, 5358,
553, 9876, 14011, 4434, 5358, 553, 21484, 4514, 17170, 25871,
8085, 5358, 553, 17489, 6945, 11097, 13535, 641, 33623, 1484,
5835, 1689, 2063, 5358, 569, 5835, 671, 16225, 3422, 189,
13, 23158, 27894, 33227, 1022, 11396, 3347, 1813, 1504, 6566,
1813, 355, 9155, 8633, 1504, 2063, 1813, 189, 13, 23158,
813, 7817, 5358, 2],
[ 586, 11835, 739, 355, 8417, 2300, 916, 16892, 5077, 580,
15549, 1434, 996, 307, 387, 355, 1997, 8236, 775, 8417,
2152, 6496, 12176, 3728, 7692, 1503, 1528, 1014, 17887, 18001,
3250, 355, 6250, 3896, 2670, 20882, 16080, 14005, 3993, 1503,
12295, 8930, 3250, 420, 4208, 4326, 5367, 8417, 2152, 4632,
37591, 355, 1379, 8417, 2300, 32824, 3250, 9015, 18945, 16714,
5908, 13551, 30330, 23756, 2152, 15976, 657, 5923, 8417, 586,
16080, 1528, 8941, 5917, 14035, 5895, 14092, 4353, 29445, 355,
8790, 21595, 1450, 15911, 31116, 10345, 29940, 10874, 1125, 4829,
16080, 7093, 22939, 737, 262, 387, 72, 272, 831, 455,
72, 915, 8860, 20725, 1934, 1084, 5478, 420, 4415, 8417,
26263, 12726, 553, 10875, 34820, 355, 1266, 5498, 586, 14907,
32795, 11835, 904, 19934, 1528, 19531, 20517, 1349, 19472, 28879,
671, 8417, 26263, 17020, 5963, 22388, 11900, 12669, 13240, 1042,
9783, 355, 7242, 714, 1806, 775, 6500, 355, 11526, 10185,
1293, 1665, 15984, 7092, 1617, 8417, 2300, 420, 19972, 25622,
10875, 17500, 26523, 2391, 8417, 2300, 355, 6751, 2836, 13539,
8247, 373, 30201, 5498, 420, 8417, 2300, 586, 8523, 19358,
1298, 12176, 30939, 10739, 964, 4318, 10875, 420, 11900, 16080,
904, 9783, 355, 22464, 9658, 355, 8417, 2300, 13561, 2054,
4983, 4829, 30800, 7262, 420, 11394, 8417, 35143, 11937, 15682,
8417, 2300, 7283, 10875, 355, 1016, 4179, 5039, 14027, 26215,
26835, 671, 15095, 189, 1165, 15095, 11900, 6184, 1125, 3244,
3687, 622, 8785, 1121, 891, 13765, 671, 10199, 189, 13,
210, 6940, 3728, 9552, 6082, 8417, 2300, 916, 4375, 714,
3679, 1806, 10567, 915, 189, 13, 210, 16131, 11835, 8417,
2300, 189, 13, 24075, 3728, 8417, 2300, 916, 4829, 3687,
9000, 27689, 8417, 2300, 1300, 11243, 2062, 28431, 27689, 11835,
8417, 2300, 13224, 4829, 3687, 15964, 915, 189, 13, 27340,
11835, 8417, 2300, 189, 13, 27340, 3982, 3728, 8417, 2300,
916, 4375, 2450, 3272, 1234, 19083, 553, 3512, 3121, 1728,
641, 3092, 2113, 7843, 915, 189, 13, 210, 15402, 8417,
2300, 189, 13, 10057, 108, 12693, 14624, 29379, 4719, 6533,
739, 916, 16148, 10981, 21350, 9067, 1203, 8931, 1258, 11835,
4719, 6533, 739, 20393, 189, 13, 210, 19546, 1517, 11835,
8417, 2300, 189, 13, 210, 23928, 168, 117, 245, 6279,
114, 240, 170, 100, 124, 168, 117, 228, 6279, 100,
124, 168, 117, 228, 171, 238, 224, 41356, 236, 24175,
11082, 10981, 21350, 9067]]),
'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, 13110, 34800,
13535, 916, 33156, 10, 256, 576, 387, 479, 681, 5453,
10955, 915, 24124, 5317, 13110, 6573, 20757, 13535, 355, 5358,
1490, 583, 28056, 1407, 3855, 671, 6113, 189, 6732, 4302,
9488, 3434, 6900, 1322, 355, 37336, 9825, 4608, 13461, 1359,
5358, 355, 5317, 13110, 34800, 4433, 7189, 25722, 29747, 13110,
1498, 12047, 6347, 23563, 2139, 2066, 420, 29288, 25, 15,
7635, 39288, 355, 1484, 5835, 6272, 23, 15, 4180, 39288,
355, 5358, 4516, 11621, 23, 15, 10641, 4887, 1712, 420,
2450, 31163, 8085, 11621, 5358, 553, 9888, 2731, 21335, 5358,
553, 9876, 14011, 4434, 5358, 553, 21484, 4514, 17170, 25871,
8085, 5358, 553, 17489, 6945, 11097, 13535, 641, 33623, 1484,
5835, 1689, 2063, 5358, 569, 5835, 671, 16225, 3422, 189,
13, 23158, 27894, 33227, 1022, 11396, 3347, 1813, 1504, 6566,
1813, 355, 9155, 8633, 1504, 2063, 1813, 189, 13, 23158,
813, 7817, 5358, 2],
[ 586, 11835, 739, 355, 8417, 2300, 916, 16892, 5077, 580,
15549, 1434, 996, 307, 387, 355, 1997, 8236, 775, 8417,
2152, 6496, 12176, 3728, 7692, 1503, 1528, 1014, 17887, 18001,
3250, 355, 6250, 3896, 2670, 20882, 16080, 14005, 3993, 1503,
12295, 8930, 3250, 420, 4208, 4326, 5367, 8417, 2152, 4632,
37591, 355, 1379, 8417, 2300, 32824, 3250, 9015, 18945, 16714,
5908, 13551, 30330, 23756, 2152, 15976, 657, 5923, 8417, 586,
16080, 1528, 8941, 5917, 14035, 5895, 14092, 4353, 29445, 355,
8790, 21595, 1450, 15911, 31116, 10345, 29940, 10874, 1125, 4829,
16080, 7093, 22939, 737, 262, 387, 72, 272, 831, 455,
72, 915, 8860, 20725, 1934, 1084, 5478, 420, 4415, 8417,
26263, 12726, 553, 10875, 34820, 355, 1266, 5498, 586, 14907,
32795, 11835, 904, 19934, 1528, 19531, 20517, 1349, 19472, 28879,
671, 8417, 26263, 17020, 5963, 22388, 11900, 12669, 13240, 1042,
9783, 355, 7242, 714, 1806, 775, 6500, 355, 11526, 10185,
1293, 1665, 15984, 7092, 1617, 8417, 2300, 420, 19972, 25622,
10875, 17500, 26523, 2391, 8417, 2300, 355, 6751, 2836, 13539,
8247, 373, 30201, 5498, 420, 8417, 2300, 586, 8523, 19358,
1298, 12176, 30939, 10739, 964, 4318, 10875, 420, 11900, 16080,
904, 9783, 355, 22464, 9658, 355, 8417, 2300, 13561, 2054,
4983, 4829, 30800, 7262, 420, 11394, 8417, 35143, 11937, 15682,
8417, 2300, 7283, 10875, 355, 1016, 4179, 5039, 14027, 26215,
26835, 671, 15095, 189, 1165, 15095, 11900, 6184, 1125, 3244,
3687, 622, 8785, 1121, 891, 13765, 671, 10199, 189, 13,
210, 6940, 3728, 9552, 6082, 8417, 2300, 916, 4375, 714,
3679, 1806, 10567, 915, 189, 13, 210, 16131, 11835, 8417,
2300, 189, 13, 24075, 3728, 8417, 2300, 916, 4829, 3687,
9000, 27689, 8417, 2300, 1300, 11243, 2062, 28431, 27689, 11835,
8417, 2300, 13224, 4829, 3687, 15964, 915, 189, 13, 27340,
11835, 8417, 2300, 189, 13, 27340, 3982, 3728, 8417, 2300,
916, 4375, 2450, 3272, 1234, 19083, 553, 3512, 3121, 1728,
641, 3092, 2113, 7843, 915, 189, 13, 210, 15402, 8417,
2300, 189, 13, 10057, 108, 12693, 14624, 29379, 4719, 6533,
739, 916, 16148, 10981, 21350, 9067, 1203, 8931, 1258, 11835,
4719, 6533, 739, 20393, 189, 13, 210, 19546, 1517, 11835,
8417, 2300, 189, 13, 210, 23928, 168, 117, 245, 6279,
114, 240, 170, 100, 124, 168, 117, 228, 6279, 100,
124, 168, 117, 228, 171, 238, 224, 41356, 236, 24175,
11082, 10981, 21350, 9067]])})
'''
tokenizer.pad_token, tokenizer.pad_token_id
'''
('<pad>', 3)
'''
tokenizer.eos_token, tokenizer.eos_token_id
'''
('</s>', 2)
'''
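With mlm=False the collator copies input_ids into labels and sets padded positions to -100 so they are ignored by the loss (the one-position shift for next-token prediction happens inside the model). A quick verification on one batch:
batch = next(iter(dl))
pad_positions = batch["attention_mask"] == 0
assert (batch["labels"][pad_positions] == -100).all()  # padding ignored by the loss
assert (batch["labels"][~pad_positions] == batch["input_ids"][~pad_positions]).all()  # labels mirror inputs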
model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-389m-zh")
args = TrainingArguments(
    output_dir="./causal_lm",
    per_device_train_batch_size=4,   # with gradient_accumulation_steps=8, an effective batch size of 32, just slower
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)
trainer = Trainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
trainer.train()
from transformers import pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
pipe("西安交通大学博物馆(Xi'an Jiaotong University Museum)是一座位于西安", max_length=128, do_sample=True)