当前位置:   article > 正文

自然语言推断:微调BERT_自然语言推断:微调bert

自然语言推断:微调bert

将下载一个预训练好的小版本的BERT,然后对其进行微调,以便在SNLI数据集上进行自然语言推断。

  1. import json
  2. import multiprocessing
  3. import os
  4. from mxnet import gluon, np, npx
  5. from mxnet.gluon import nn
  6. from d2l import mxnet as d2l
  7. npx.set_np()

加载预训练的BERT

在下面,我们提供了两个版本的预训练的BERT:“bert.base”与原始的BERT基础模型一样大,需要大量的计算资源才能进行微调,而“bert.small”是一个小版本,以便于演示。

  1. d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
  2. '7b3820b35da691042e5d34c0971ac3edbd80d3f4')
  3. d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
  4. 'a4e718a47137ccd1809c9107ab4f5edd317bae2c')

两个预训练好的BERT模型都包含一个定义词表的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现了以下load_pretrained_model函数来加载预先训练好的BERT参数。

  1. def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
  2. num_heads, num_layers, dropout, max_len, devices):
  3. data_dir = d2l.download_extract(pretrained_model)
  4. # 定义空词表以加载预定义词表
  5. vocab = d2l.Vocab()
  6. vocab.idx_to_token = json.load(open(os.path.join(data_dir,
  7. 'vocab.json')))
  8. vocab.token_t
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/367406
推荐阅读
相关标签
  

闽ICP备14008679号