I recently needed to further pre-train BERT on a new domain, so I wrote the corresponding code by following the Hugging Face documentation. This post starts from a checkpoint provided by Hugging Face and continues training on the task-specific domain. Because of the project's confidentiality agreement, the code and data cannot be released in full; only the key parts are shown below.
Re-training BERT here means running the Masked Language Model (MLM) objective on my own dataset. I recall reading in some paper that Next Sentence Prediction contributes little to downstream tasks (please correct me if I am misremembering), and since this re-training uses a corpus of short sentences, only the MLM task is considered. Borrowing the example from Hung-yi Lee's lecture slides to illustrate the MLM objective:
Given a sentence, its tokens are masked at random with a certain probability (BERT uses 15%). The MLM objective is to predict the word at each [MASK] position over BERT's entire vocabulary, i.e. the model outputs a |V|-dimensional probability vector, where V denotes the vocabulary; taking softmax(P_{|V|}) then yields the most likely prediction. The point of MLM is to make BERT familiar with the surrounding context through self-attention.
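To make this concrete, here is a minimal sketch of reading off that |V|-dimensional distribution at a [MASK] position. The checkpoint name 'bert-base-chinese' and the example sentence are placeholders, not the data used in this project:

import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model.eval()

text = '今天天气真[MASK]。'                      # placeholder sentence
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits              # shape: [1, seq_len, |V|]

# positions of [MASK] in the input
mask_pos = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
probs = torch.softmax(logits[0, mask_pos], dim=-1)   # the |V|-dim distribution per [MASK]
print(tokenizer.convert_ids_to_tokens(probs.argmax(dim=-1).tolist()))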
BERT's token vocabulary is finite; if the domain contains characters that are not in the vocabulary, they will be mapped to [UNK]. The vocabulary therefore needs to be extended, as shown in get_new_tokens below:
import os
import pandas as pd
from transformers import BertTokenizer, BertForMaskedLM


class DataLoader:
    def __init__(self, in_dir='../初始数据.csv', out_dir='checkpoints/policybert',
                 bert_dir='checkpoints/bert', train_source='title',
                 batch_size: int = 64, max_len: int = 64, shuffle: bool = True,
                 mask_token='[MASK]', mask_rate=0.15):
        super(DataLoader, self).__init__()
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.bert_dir = bert_dir
        self.train_source = train_source
        if len(os.listdir(out_dir)) == 0:
            # no saved checkpoint yet, so fall back to the original BERT tokenizer
            self.tokenizer = BertTokenizer.from_pretrained(self.bert_dir)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(self.out_dir)
        self.get_data()
        # iterator that feeds BERT one batch at a time
        self.bert_iter = BERTMLMDataIter(datas=self.datas, tokenizer=self.tokenizer,
                                         max_len=max_len, batch_size=batch_size)
        self.model = BertForMaskedLM.from_pretrained(self.bert_dir)
        # note: resize the embedding matrix before use, in case the vocabulary was extended
        self.model.resize_token_embeddings(len(self.tokenizer))

    def get_data(self):
        '''
        Load the raw text. The source is stored as an excel/csv file and only
        the 'title' column is used for training.
        '''
        self.datas = []
        frame = list(pd.read_csv(self.in_dir)[self.train_source].values)
        self.datas.extend(frame)

    def get_new_tokens(self):
        '''
        Add new tokens to the BERT vocabulary.
        '''
        self.new_tokens = []
        for data in self.datas:
            for word in data:
                if word not in self.tokenizer.vocab:
                    # Chinese model: keep only CJK characters, drop other special symbols
                    if u'\u4e00' <= word <= u'\u9fff' and word not in self.new_tokens:
                        self.new_tokens.append(word)
        self.tokenizer.add_tokens(self.new_tokens)
        self.tokenizer.save_pretrained(self.out_dir)  # save the extended vocabulary
When the model is loaded again after the vocabulary has been extended, keep in mind that BERT's first layer is the Embedding layer, whose parameter count depends on the vocabulary size. Since the vocabulary has changed, the embedding matrix must be resized, i.e.: self.model.resize_token_embeddings(len(self.tokenizer)).
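As a stripped-down sketch of this add_tokens → resize step (the paths and the example tokens are placeholders):

from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('checkpoints/bert')   # placeholder path
model = BertForMaskedLM.from_pretrained('checkpoints/bert')

new_tokens = ['硅', '烷']                      # hypothetical domain characters
num_added = tokenizer.add_tokens(new_tokens)   # returns how many tokens were actually new
if num_added > 0:
    # the word_embeddings matrix must grow to len(tokenizer) rows, otherwise
    # the new ids index past the old |V| and the forward pass fails
    model.resize_token_embeddings(len(tokenizer))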
The csv used for loading contains several hundred thousand rows, so to avoid straining memory I load the data dynamically by building the following iterator:
import random
import torch
from transformers import BertTokenizer


class BERTMLMDataIter():
    '''
    Data iterator for BertForMaskedLM. The input is the ids of a masked sentence,
    e.g. "The capital of France is [MASK]" after conversion to ids; the target is
    the prediction for the [MASK] positions.
    '''
    def __init__(self, datas: list, tokenizer: BertTokenizer, batch_size: int = 32,
                 max_len: int = 128, shuffle: bool = True,
                 mask_token='[MASK]', mask_rate=0.15):
        super().__init__()
        self.datas = datas
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_len = max_len
        self.shuffle = shuffle
        self.Mask_id = self.tokenizer.convert_tokens_to_ids(mask_token)
        self.mask_rate = mask_rate
        # first initialisation
        self.reset()
        self.ipts = 0

    def reset(self):
        print("dataiter reset, reloading data")
        if self.shuffle:
            random.shuffle(self.datas)
        self.data_iter = iter(self.datas)

    def random_mask(self, tokens, rate):
        '''
        :param tokens: token ids of the sentence to be masked
        :param rate: masking probability
        :return: masked ids and the corresponding labels
        '''
        mask_tokens, label = [], []
        for word in tokens:
            if random.random() <= rate:
                mask_tokens.append(self.Mask_id)
                label.append(word)
            else:
                mask_tokens.append(word)
                label.append(-100)  # -100 means this position is ignored by the loss
        return mask_tokens, label

    def get_data(self):
        '''
        Full-batch variant: mask and label the whole dataset at once
        (superseded by get_batch_data below, which works batch by batch).
        '''
        data_ids, labels, att_masks = [], [], []
        for data in self.datas:
            data_id = self.tokenizer.encode(data)
            masked_data, label = self.random_mask(data_id, self.mask_rate)
            att_mask = [1] * len(masked_data) + [0] * (self.max_len - len(masked_data))
            masked_data = masked_data + [0] * (self.max_len - len(masked_data))
            label = label + [-100] * (self.max_len - len(label))
            data_ids.append(masked_data)
            labels.append(label)
            att_masks.append(att_mask)
        return data_ids, labels, att_masks

    def get_batch_data(self):
        '''Assemble one batch of masked ids, attention masks and labels.'''
        batch_data = []
        for i in self.data_iter:
            batch_data.append(i)
            if len(batch_data) == self.batch_size:
                break
        if len(batch_data) < 1:
            return None
        data_ids, labels, att_masks = [], [], []
        for data in batch_data:
            data_id = self.tokenizer.encode(data)
            masked_data, label = self.random_mask(data_id, self.mask_rate)
            if len(masked_data) < self.max_len:
                att_mask = [1] * len(masked_data) + [0] * (self.max_len - len(masked_data))
            else:
                att_mask = [1] * self.max_len
                masked_data = masked_data[:self.max_len]
                label = label[:self.max_len]
            masked_data = masked_data + [0] * (self.max_len - len(masked_data))
            label = label + [-100] * (self.max_len - len(label))
            data_ids.append(masked_data)
            labels.append(label)
            att_masks.append(att_mask)
        batch_ipts = {}
        batch_ipts['ids'] = torch.LongTensor(data_ids)
        batch_ipts['mask'] = torch.LongTensor(att_masks)
        batch_ipts['label'] = torch.LongTensor(labels)
        return batch_ipts

    def __iter__(self):
        return self

    def __next__(self):
        if self.ipts is None:
            self.reset()
        self.ipts = self.get_batch_data()
        if self.ipts is None:
            raise StopIteration
        return self.ipts
get_batch_data processes one batch at a time during training and feeds it straight into the model, so the full dataset does not have to be preprocessed up front.
Within each batch, random_mask applies the 15% random masking.
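For reference, a hypothetical use of the iterator defined above looks like this (the tokenizer directory and the example titles are placeholders):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('checkpoints/policybert')   # placeholder path
datas = ['示例标题一', '示例标题二', '示例标题三']                         # placeholder titles
bert_iter = BERTMLMDataIter(datas=datas, tokenizer=tokenizer, batch_size=2, max_len=64)

for batch in bert_iter:
    # each batch is a dict of LongTensors of shape [batch_size, max_len]
    print(batch['ids'].shape, batch['mask'].shape, batch['label'].shape)
    break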
The pre-training itself uses BertForMaskedLM. Its inputs are the masked token_ids, i.e. a sequence of id numbers in which [MASK] corresponds to id 103; an attention_mask that tells the model which tokens take part in the self-attention computation and which do not; and labels, where every unmasked position carries the label -100, meaning that position is excluded from the loss computation.
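For illustration, a single forward pass with hand-built tensors might look like this (the token ids below are placeholders; only the [MASK]=103, padding and -100 conventions matter, and the model path is a placeholder):

import torch
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained('checkpoints/bert')   # placeholder path

# [CLS]=101, [SEP]=102, [MASK]=103 in the standard BERT vocabularies; 0 is padding
input_ids      = torch.LongTensor([[101, 2769, 103, 704, 1744, 102, 0, 0]])
attention_mask = torch.LongTensor([[1,   1,   1,   1,   1,    1,   0, 0]])
# only the masked position gets a real label; everything else is -100
labels         = torch.LongTensor([[-100, -100, 4263, -100, -100, -100, -100, -100]])

out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(out.loss)   # same as out[0]: cross-entropy over positions whose label != -100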
Training itself is then straightforward: just call the wrapped train method:
# (method of the training wrapper class; imports shown for completeness)
import time
import torch
import torch.optim as optim
from tqdm import tqdm


def train_my_bert(self):
    ttt = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(0)
    self.model.to(device)  # move the model to the GPU before training
    CE = torch.nn.CrossEntropyLoss()  # not actually used: the model returns its own loss
    optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=0)
    best_loss = 10000000
    for epoch in tqdm(range(self.epoch)):
        total_loss = 0.0
        # self.loader.bert_iter is the data iterator built for BertForMaskedLM
        for step, ipt in tqdm(enumerate(self.loader.bert_iter)):
            # one batch of training data
            ipt = {k: v.to(device) for k, v in ipt.items()}
            out = self.model(input_ids=ipt['ids'],
                             attention_mask=ipt['mask'],
                             labels=ipt['label'])
            loss = out[0]
            total_loss += loss.data.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
            optimizer.step()
            self.model.zero_grad()
            if step % 10 == 0:
                print('current batch loss:{}'.format(loss.data.item()))
        if total_loss < best_loss:
            # Save the best (lowest-loss) model. For later reuse, do not use torch.save;
            # use the save_pretrained method provided by transformers instead.
            print('save best model to {}!!!'.format(self.out_dir))
            best_loss = total_loss
            self.model.save_pretrained('{}_{}'.format(self.out_dir, ttt))
Note that BertForMaskedLM already returns the loss as part of its output, so the custom loss function defined above is never used; the BertForMaskedLM documentation covers this very clearly. Looking at the training log, training is actually quite fast, although the sheer size of the dataset still makes it take a while.
After the training above finishes, if you want to use BERT for downstream tasks (classification and the like), keep in mind that BertForMaskedLM's parameters are organised differently from BertModel's, so a small trick is needed to migrate them. A simple, direct approach is to assign the BERT-layer parameters of BertForMaskedLM to a BertModel. This may sound dry in words, so let's look at how the layers of the two models actually differ. First, the BertModel we are all familiar with:
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1)-(11): BertLayer( ... identical in structure to layer (0) ... )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
Then BertForMaskedLM:
BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0)-(11): BertLayer( ... identical in structure to the BertLayer shown above ... )
      )
    )
  )
  (cls): BertOnlyMLMHead(
    (predictions): BertLMPredictionHead(
      (transform): BertPredictionHeadTransform(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (transform_act_fn): GELUActivation()
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (decoder): Linear(in_features=768, out_features=30522, bias=True)
    )
  )
)
Compared with BertForMaskedLM, BertModel has an extra Pooler and lacks the BertOnlyMLMHead; the difference simply comes from the two models predicting over different label spaces. The BertLayers in the middle and the Embedding at the front can all be reused, so we can simply hand the BertForMaskedLM weights over to the BertModel:
self.bert.embeddings = self.mlm_bert.bert.embeddings
self.bert.encoder = self.mlm_bert.bert.encoder
After that, self.bert can be used for the downstream work. When saving, remember to save the BertModel as well, so that next time it can be loaded directly.
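Putting the hand-off together, a minimal end-to-end sketch might look like this (directory names are placeholders; updating config.vocab_size is my addition, to keep the saved config consistent with the possibly extended embedding matrix):

from transformers import BertModel, BertForMaskedLM

mlm_bert = BertForMaskedLM.from_pretrained('checkpoints/policybert')  # re-trained MLM model
bert = BertModel.from_pretrained('checkpoints/bert')                  # plain BertModel

# hand over the shared parts: the embeddings and the 12 BertLayers;
# the pooler keeps its original pre-trained weights
bert.embeddings = mlm_bert.bert.embeddings
bert.encoder = mlm_bert.bert.encoder

# keep the config consistent with the new embedding matrix, otherwise reloading
# the saved model later may complain about a size mismatch
bert.config.vocab_size = mlm_bert.config.vocab_size

bert.save_pretrained('checkpoints/policybert_encoder')                # placeholder output dir

In my experience, calling BertModel.from_pretrained on the directory saved by BertForMaskedLM.save_pretrained should also pick up the shared bert.* weights (only the pooler gets freshly initialised, with a warning), but the explicit assignment above makes the transfer unambiguous.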
Training BERT does take a fairly large dataset, and whether the parameters obtained by this re-training actually improve our downstream tasks is still being evaluated; this post only offers one possible approach. If you are fine-tuning on a small dataset, I would instead recommend the form of a main task plus an auxiliary MLM task, so that BERT adapts better to the task at hand.
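For reference, a rough sketch of that main-task-plus-auxiliary-MLM idea could look like the following (the class name, the classification head and the weighting factor are all illustrative, not something from this project):

import torch.nn as nn
from transformers import BertForMaskedLM

class BertWithAuxMLM(nn.Module):
    def __init__(self, bert_dir, num_classes, mlm_weight=0.1):
        super().__init__()
        self.mlm_bert = BertForMaskedLM.from_pretrained(bert_dir)
        hidden = self.mlm_bert.config.hidden_size
        self.classifier = nn.Linear(hidden, num_classes)   # hypothetical main-task head
        self.mlm_weight = mlm_weight
        self.ce = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, cls_labels, mlm_labels):
        out = self.mlm_bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=mlm_labels,
                            output_hidden_states=True)
        # [CLS] vector from the last hidden layer drives the main task
        cls_vec = out.hidden_states[-1][:, 0]
        cls_loss = self.ce(self.classifier(cls_vec), cls_labels)
        # weighted sum of the main loss and the auxiliary MLM loss
        return cls_loss + self.mlm_weight * out.loss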