赞
踩
由Transformers Trainer的文档中可知,Trainer函数有一个参数data_collator,其值也为一个函数,用于从一个list of elements来构造一个batch。这个函数其实就是torch.utils.data.DataLoader中的collate_fn。如果还不知道collate_fn是做何用处,请参考这篇文档。
要用到这个函数,要符合如下两个条件:
train_dataset
或eval_dataset
是torch.utils.data.Dataset或torch.utils.data.IterableDataset的实体train_dataset
或eval_dataset
(torch.utils.data.Dataset)加载入DataLoader后,得到的batch不可用,还不能直接加入到model的forward中计算这里假设读者已经知道torch.utils.data.DataLoader的collate_fn用法,只介绍Trainer的data_collator和torch.utils.data.DataLoader的collate_fn的差异。
差异就是,输出格式!torch.utils.data.DataLoader的collate_fn的输出可以是各种格式,但Trainer的data_collator的输出只能是一个dict,这个dict的键必须包含“input_ids”,“attention_mask”等transformers models正常运算必要的参数的名称,如果需要,也可以加入任何transformers model.forward()可接受的参数名,而这些键对应的值也必须是transformers model中该键应该对应的输入值。
如果想让模型自动训练loss,还要在这个dict中加入如下键值对:{“labels”: labels in tensor type},这样模型的输出里就有loss了。
看两段源码其实就差不多明白了:
第一张图中,这个DataLoader就是一个纯粹的torch.utils.data.DataLoader,self.data_collator就是输入的data_collator函数。所以,这个data_collator就彻彻底底是一个DataLoader的collate_fn啊
第二张图中,input就是如下迭代的结果(其中的dataloader就是第一张图中的dataloader)
for step, inputs in enumerate(DataLoader)
所以,inputs的键值对必须要与model.forwards()的参数相对应也是显然的
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。