当前位置:   article > 正文

bert层次位置编码_位置编码层

位置编码层

今天学习了bert之中的层次位置编码,感觉可以很好地用到maxlen超出512的部分苏神的层次位置编码
层次位置编码的方案公式 q i ∗ n + j = α u i + ( 1 − α ) u j q_{i*n+j} = \alpha u_{i}+(1-\alpha)u_{j} qin+j=αui+(1α)uj
这里的初始alpha最好设定为0.4,也就是说0~511的位置编码不变,从第512的时候,对应的坐标为(1,0),计算公式为512/512 = 1,512%512 = 0,接下来计算公式
q 512 = 0.4 ∗ u 1 + 0.6 ∗ u 0 q_{512} = 0.4*u_{1}+0.6*u_{0} q512=0.4u1+0.6u0,同理 q 513 = 0.4 ∗ u 1 + 0.6 ∗ u 1 q_{513} = 0.4*u_{1}+0.6*u_{1} q513=0.4u1+0.6u1,后面的内容以此类推,
使用pythonicforbert调用层次位置编码的代码

maxlen = 500
batch_size = 4
origin_max_position_embeddings = config.max_position_embeddings
config.max_position_embeddings = maxlen
config.with_mlm = False
if maxlen > origin_max_position_embeddings:
    model_data = model.state_dict()
    roberta = Roberta(config)
    roberta = get_data(roberta,'/home/xiaoguzai/模型/hated-roberta/pytorch_model.bin')
    model = ClassificationModel(roberta,config,1)
    current_position_embedding = model_data['model.robertaembeddings.position_embeddings_layer.weight']
    new_position_embedding = current_position_embedding
    for currentindex in range(origin_max_position_embeddings,maxlen):
        index1 = currentindex//origin_max_position_embeddings-1
        index2 = currentindex%origin_max_position_embeddings
        embedding_data = 0.4*new_position_embedding[index1]+0.6*new_position_embedding[index2]
        embedding_data = torch.tensor([embedding_data.tolist()])
        new_position_embedding = torch.cat((new_position_embedding,embedding_data),0)
    print('new_position_embedding = ')
    print(new_position_embedding.size())
    model_data['model.robertaembeddings.position_embeddings_layer.weight'] = new_position_embedding
    model.load_state_dict(model_data)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

有的时候maxlen更长效果不好并不一定是位置编码的原因,而可能是因为学到了更多无用的信息。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/348663
推荐阅读
相关标签
  

闽ICP备14008679号