当前位置:   article > 正文

BERT Pytorch版本 源码解析(一)

BERT Pytorch版本 源码解析(一)

BERT Pytorch版本 源码解析(一)

 

一、BERT安装方式

pip install pytorch-pretrained-bert

二、BertPreTrainModel: 

  • 一个用于获取预训练好权重的抽象类,一个用于下载和载入预训练模型的简单接口

1、初始化函数(def __init__(self, config, *inputs, **kwargs)):

  1. def __init__(self, config, *inputs, **kwargs):
  2. super(BertPreTrainedModel, self).__init__()
  3. if not isinstance(config, BertConfig):
  4. raise ValueError(
  5. "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
  6. "To create a model from a Google pretrained model use "
  7. "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
  8. self.__class__.__name__, self.__class__.__name__
  9. ))
  10. self.config = config

初始化函数主要是用于传入BertConfig的一个对象,这样可以获得Bert模型所需的模型参数,例如hidden_size等

2、最重要的from_pretrained函数: def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)

  1. pretrained_model_name_or_path: either:
  2. - a str with the name of a pre-trained model to load selected in the list of:
  3. . `bert-base-uncased`
  4. . `bert-large-uncased`
  5. . `bert-base-cased`
  6. . `bert-large-cased`
  7. . `bert-base-multilingual-uncased`
  8. . `bert-base-multilingual-cased`
  9. . `bert-base-chinese`
  10. - a path or url to a pretrained model archive containing:
  11. . `bert_config.json` a configuration file for the model
  12. . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
  13. - a path or url to a pretrained model archive containing:
  14. . `bert_config.json` a configuration file for the model
  15. . `model.chkpt` a TensorFlow checkpoint

看一下pretrained_model_name_or_path 这个参数,这个参数可以是两种,一种是你需要下载的预训练的BERT模型类别名称,另一种是你已经下好的BERT预训练模型的路径。

这就是为什么有一些博客上加载预训练模型是直接from_pretrain('bert-base-uncased'),而有一些上面写的是bert模型的路径,这里个人建议是把一些预训练模型下好,然后放到一个固定的文件夹下面可以避免重复下载,每次用的时候直接调用就好了。

 

  1. if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
  2. archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
  3. else:
  4. archive_file = pretrained_model_name_or_path

这部分代码就是解析你传入的pretrained_model_name_or_path参数是一个模型名称还是一个模型路径,首先是进行判断是否是模型名称,不是的话默认为下载好的模型路径。

  1. PRETRAINED_MODEL_ARCHIVE_MAP = {
  2. 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
  3. 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
  4. 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
  5. 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
  6. 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
  7. 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
  8. 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
  9. }

这个就是一个Map,用于将模型名称转换成相对应的URL,所以想预下载的同志们直接在这里copy 一下URL就可以下载了,并不需要去找一下百度云哦,毕竟百度云下载的东西也未必是真的有用(小编就被坑了,很难受 QAQ)。

  1. try:
  2. resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
  3. except EnvironmentError:
  4. logger.error(
  5. "Model name '{}' was not found in model name list ({}). "
  6. "We assumed '{}' was a path or url but couldn't find any file "
  7. "associated to this path or url.".format(
  8. pretrained_model_name_or_path,
  9. ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
  10. archive_file))
  11. return None
  12. '''
  13. cached_path函数的内部实现
  14. '''
  15. def cached_path(url_or_filename, cache_dir=None):
  16. """
  17. Given something that might be a URL (or might be a local path),
  18. determine which. If it's a URL, download the file and cache it, and
  19. return the path to the cached file. If it's already a local path,
  20. make sure the file exists and then return the path.
  21. """
  22. if cache_dir is None:
  23. cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
  24. if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
  25. url_or_filename = str(url_or_filename)
  26. if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
  27. cache_dir = str(cache_dir)
  28. parsed = urlparse(url_or_filename)
  29. if parsed.scheme in ('http', 'https', 's3'):
  30. # URL, so get it from the cache (downloading if necessary)
  31. return get_from_cache(url_or_filename, cache_dir)
  32. elif os.path.exists(url_or_filename):
  33. # File, and it exists.
  34. return url_or_filename
  35. elif parsed.scheme == '':
  36. # File, but it doesn't exist.
  37. raise EnvironmentError("file {} not found".format(url_or_filename))
  38. else:
  39. # Something unknown
  40. raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

这部分代码就是对于你传入的文件路径或者是转换成的URL做一个处理,如果是URL的话就进行下载的操作,如果是一个本地文件的话就进行文件路径检查以及返回文件路径。

 

三、BertModel

1、BertModel 大概是实战中最应该掌握的模块。初始化函数如下:

  1. def __init__(self, config):
  2. super(BertModel, self).__init__(config)
  3. self.embeddings = BertEmbeddings(config)
  4. self.encoder = BertEncoder(config)
  5. self.pooler = BertPooler(config)
  6. self.apply(self.init_bert_weights)

可以看出初始化 BertModel 的时候是需要传一个config的,这个config就是BertConfig的一个对象,那么我们在项目中要运用一些参数预训练好的模型来进行建模时应该怎么操作呢?

  1. self.bert = BertModel.from_pretrained(model_path)
  2. self.hidden_size = self.bert.config.hidden_size

这个写法是预先下载好了 bert 的预训练模型的写法,将你自己下好的预训练模型的路径传进去就好了,如果没有下载过可以看一下 BertPreTrainModel 部分的解释,建议是将你的bert模型下载好保存到一个固定的文件夹中,以后要用到的时候直接调用就好了,BertModel可以加载几种的预训练模型(包括中文的bert)按需下载就好了。到现在我们就加载好了一个预训练好的bert模型。这种方式去加载的bert模型如果要查看相关的参数配置信息,只需要如上述第二行的方式即可获取。

注意:BERT的参数太大,不要在笔记本上面测试代码能不能跑,很容易就死机了。。。

BertModel的内部运行方式解析:

  1. def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
  2. if attention_mask is None:
  3. attention_mask = torch.ones_like(input_ids)
  4. if token_type_ids is None:
  5. token_type_ids = torch.zeros_like(input_ids)
  6. # We create a 3D attention mask from a 2D tensor mask.
  7. # Sizes are [batch_size, 1, 1, to_seq_length]
  8. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  9. # this attention mask is more simple than the triangular masking of causal attention
  10. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  11. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  12. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  13. # masked positions, this operation will create a tensor which is 0.0 for
  14. # positions we want to attend and -10000.0 for masked positions.
  15. # Since we are adding it to the raw scores before the softmax, this is
  16. # effectively the same as removing these entirely.
  17. extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
  18. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  19. embedding_output = self.embeddings(input_ids, token_type_ids)
  20. encoded_layers = self.encoder(embedding_output,
  21. extended_attention_mask,
  22. output_all_encoded_layers=output_all_encoded_layers)
  23. sequence_output = encoded_layers[-1]
  24. pooled_output = self.pooler(sequence_output)
  25. if not output_all_encoded_layers:
  26. encoded_layers = encoded_layers[-1]
  27. return encoded_layers, pooled_output

首先是参数传入,一般来说对于 基本的BertModel 以及之后的另外一些模型都是传入 input_ids, token_type_ids, attention_mask三个参数,下面解释一下三个参数的含义。

input_ids: 如果是用BertTokenizer 进行分词的,那么会自动生成对应的 tokens to ids 的函数,将你的句子直接扔进这个函数就可以得到一个用词表index描述的句子。用 Batch 的方式去训练的记得将input_ids进行padding操作。

token_type_ids: BertModel 每一次最多允许两个句子输入模型中,所以你的 token_type_ids 只能是0或者1。

attention_mask: Encoder 做 attention的时候需要利用这个部分进行对padding的无用信息进行舍去,不进行attention操作。

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号