赞
踩
今天分析一下NLP中的pad操作代码:
该方法的作用是将输入的序列列表seqs
进行填充操作,使其具有相同的长度,以便进行批处理。填充使用指定的pad_token
进行,并生成一个对应的mask标志列表,用于标记哪些部分是填充内容(值为1)和哪些部分是原始内容(值为0)。填充后的序列列表和掩盖标志列表将作为方法的返回值,供进一步使用或处理。
- @staticmethod
- def _pad_seqs(seqs, pad_token):
- # 定义变量pad_length,通过遍历seqs token列表获取其中最长token的长度,从而将token列表的所有seq长度都填充到pad_length
- pad_length = max([len(seq) for seq in seqs])
- # 对seqs中的每个token列表进行填充,填充内容为pad_token,填充至长度为pad_length
- padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs]
- # 创建一个mask标志列表,长度为seq的长度,并将前部填充部分置为0(未经过pad),后部未填充部分置为1(经过了pad操作),添加到masks列表中。
- masks = [([0] * len(seq)) + ([1] * (pad_length - len(seq))) for seq in seqs]
- return padded, masks
输入:
tokens, orig_pad_masks = self._pad_seqs(tokens, self.pad_token)
输出:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。