""" @Time : 2022/10/22 @Author : Peinuan qin """ import torch import torch.nn as nn from transformers import BertTokenizer,BertModel, BertConfig from Dataset import BERTDataset class PositionalEmbedding(nn.Embedding): def __init__(self, d_model, max_len=512): super(PositionalEmbedding, self).__init__(max_len, d_model, padding_idx=0) class SegmentEmbedding(nn.Embedding): def __init__(self, d_model, segement_num=2): super(SegmentEmbedding, self).__init__(segement_num, d_model, padding_idx=0) class TokenEmbedding(nn.Embedding): def __init__(self, d_model, vocab_size,): super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=0) class BERTEmbedding(nn.Module): def __init__(self, vocab_size, d_model, drop_rate=0.1): super(BERTEmbedding, self).__init__() # super(BERTEmbedding, self).__init__() self.token_embedding = TokenEmbedding(d_model, vocab_size) self.position_embedding = PositionalEmbedding(d_model) self.segment_embedding = SegmentEmbedding(d_model) self.dropout = nn.Dropout(drop_rate) def forward(self, sequence, segment_labels, position_ids): x = self.token_embedding(sequence) + self.segment_embedding(segment_labels) + self.position_embedding(position_ids) return self.dropout(x) if __name__ == '__main__': model_name = '../bert_pretrain_base/' d_model = 768 config = BertConfig.from_pretrained(model_name) tokenizer = BertTokenizer.from_pretrained(model_name) model = BertModel.from_pretrained(model_name) bert_embedding = BERTEmbedding(vocab_size=config.vocab_size, d_model=d_model) dataset = BERTDataset(corpus_path="./corpus_chinese.txt" , tokenizer=tokenizer , seq_len=20) sample = dataset[0] input_ids = sample['input_ids'] segment_labels = sample['segment_labels'] # 是固定的 [0, ..., seq_len-1] position_ids = torch.tensor([i for i in range(len(input_ids))]) print(sample) x = bert_embedding(input_ids, segment_labels, position_ids) print(x)