The code below defines a BertEncoder that, besides the usual stack of BertLayer blocks, inserts a WordEmbeddingAdapter at a configurable layer (config.add_layer) to fuse matched-word embeddings into the character-level hidden states, then runs it on random inputs to check the output shape.

import torch
import torch.utils.checkpoint
import torch.nn as nn
from transformers.modeling_bert import BertLayer, BertConfig, BaseModelOutput
from lebert import WordEmbeddingAdapter


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.word_embedding_adapter = WordEmbeddingAdapter(config)

    def forward(
        self,
        hidden_states,
        word_embeddings,
        word_mask,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False
    ):
        # hidden_states: [batch_size, max_len, bert_dim]
        # word_embeddings: [batch_size, max_len, num_words, word_dim]
        # word_mask: [batch_size, max_len, num_words]
        # attention_mask: [batch_size, 1, 1, max_len]
        # head_mask: list of length num_hidden_layers
        # encoder_hidden_states, encoder_attention_mask: None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if getattr(self.config, 'gradient_checkpointing', False):
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)
                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask
                )
            else:
                # layer_outputs: ([batch_size, max_len, bert_dim],)
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions
                )
            # hidden_states: [batch_size, max_len, bert_dim]
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            # fuse word information into the character representations at the specified layer
            if i == self.config.add_layer:
                # hidden_states: [batch_size, max_len, bert_dim]
                hidden_states = self.word_embedding_adapter(hidden_states, word_embeddings, word_mask.byte())

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions
        )


pretrain_model_path = 'bert-base-chinese'
config = BertConfig.from_pretrained(pretrain_model_path)
config.word_embed_dim = 200
config.add_layer = 0
num_words = 3
max_len = 10
bert_dim = 768
batch_size = 4

hidden_states = torch.randn([batch_size, max_len, bert_dim])
word_embeddings = torch.randn([batch_size, max_len, num_words, config.word_embed_dim])
word_mask = torch.ones([batch_size, max_len, num_words]).long()
attention_mask = torch.ones([batch_size, 1, 1, max_len]).byte()
head_mask = [None] * config.num_hidden_layers

model = BertEncoder(config)
outputs = model(
    hidden_states=hidden_states,
    word_embeddings=word_embeddings,
    word_mask=word_mask,
    attention_mask=attention_mask,
    head_mask=head_mask,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=False
)
print('outputs[0].shape:', outputs[0].shape)
outputs[0].shape: torch.Size([4, 10, 768])
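For reference, WordEmbeddingAdapter is imported from the accompanying lebert module and is not shown above. Below is a minimal sketch of what such an adapter can look like, following the LEBERT idea of letting each character attend over its matched words; the class name WordEmbeddingAdapterSketch, the bilinear attention, and the config fields it reads (word_embed_dim, hidden_size, layer_norm_eps, hidden_dropout_prob) are assumptions for illustration, not the actual lebert implementation.

import torch
import torch.nn as nn

class WordEmbeddingAdapterSketch(nn.Module):
    """Illustrative adapter: fuses matched-word embeddings into character hidden states."""
    def __init__(self, config):
        super().__init__()
        # project word embeddings into the BERT hidden space
        self.word_transform = nn.Linear(config.word_embed_dim, config.hidden_size)
        # bilinear attention weight between a character and its candidate words
        self.attn_w = nn.Parameter(torch.randn(config.hidden_size, config.hidden_size) * 0.02)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, word_embeddings, word_mask):
        # hidden_states: [batch, max_len, hidden]
        # word_embeddings: [batch, max_len, num_words, word_dim]
        # word_mask: [batch, max_len, num_words], nonzero for real matched words
        word_states = torch.tanh(self.word_transform(word_embeddings))      # [b, L, n, h]
        # bilinear score of each character against its candidate words
        scores = torch.einsum('blh,hk,blnk->bln', hidden_states, self.attn_w, word_states)
        scores = scores.masked_fill(word_mask == 0, -10000.0)               # ignore padded words
        probs = torch.softmax(scores, dim=-1)                               # [b, L, n]
        weighted = torch.einsum('bln,blnh->blh', probs, word_states)        # [b, L, h]
        # residual add, dropout, and layer norm back into the character stream
        return self.layer_norm(hidden_states + self.dropout(weighted))

Because it takes the same (hidden_states, word_embeddings, word_mask) arguments and returns a [batch_size, max_len, bert_dim] tensor, this sketch can stand in for the real adapter when experimenting with the encoder demo above.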