当前位置:   article > 正文

text_matching-sentence_transformers 双塔结构_使用transformer实现双塔

使用transformer实现双塔

3层RoBERTa效果(RBT3)

class SentenceTransformer(nn.Layer):
    def __init__(self, pretrained_model, dropout=None):
        super().__init__()
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
        # num_labels = 2 (similar or dissimilar)
        self.classifier = nn.Linear(self.ptm.config["hidden_size"] * 3, 2)

    def forward(self,
                query_input_ids,
                title_input_ids,
                query_token_type_ids=None,
                query_position_ids=None,
                query_attention_mask=None,
                title_token_type_ids=None,
                title_position_ids=None,
                title_attention_mask=None):
        query_token_embedding, _ = self.ptm(
            query_input_ids, query_token_type_ids, query_position_ids,
            query_attention_mask)
        query_token_embedding = self.dropout(query_token_embedding)
        query_attention_mask = paddle.unsqueeze(
            (query_input_ids != self.ptm.pad_token_id
             ).astype(self.ptm.pooler.dense.weight.dtype),
            axis=2)
        # Set token embeddings to 0 for padding tokens
        query_token_embedding = query_token_embedding * query_attention_mask
        query_sum_embedding = paddle.sum(query_token_embedding, axis=1)
        query_sum_mask = paddle.sum(query_attention_mask, axis=1)
        query_mean = query_sum_embedding / query_sum_mask

        title_token_embedding, _ = self.ptm(
            title_input_ids, title_token_type_ids, title_position_ids,
            title_attention_mask)
        title_token_embedding = self.dropout(title_token_embedding)
        title_attention_mask = paddle.unsqueeze(
            (title_input_ids != self.ptm.pad_token_id
             ).astype(self.ptm.pooler.dense.weight.dtype),
            axis=2)
        # Set token embeddings to 0 for padding tokens
        title_token_embedding = title_token_embedding * title_attention_mask
        title_sum_embedding = paddle.sum(title_token_embedding, axis=1)
        title_sum_mask = paddle.sum(title_attention_mask, axis=1)
        title_mean = title_sum_embedding / title_sum_mask

        sub = paddle.abs(paddle.subtract(query_mean, title_mean))
        projection = paddle.concat([query_mean, title_mean, sub], axis=-1)

        logits = self.classifier(projection)
        probs = F.softmax(logits)

        return probs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52

双塔模型结构构造

Sentence Transformer采用了双塔(Siamese)的网络结构。Query和Title分别输入ERNIE,共享一个ERNIE参数,得到各自的token embedding特征。之后对token embedding进行pooling(此处教程使用mean pooling操作),之后输出分别记作u,v。之后将三个表征(u,v,|u-v|)拼接起来,进行二分类。网络结构如上图所示。
那么Sentence Transformer采用Siamese的网路结构,是如何提升预测速度呢?

Siamese的网络结构好处在于query和title分别输入同一套网络。如在信息搜索任务中,此时就可以将数据库中的title文本提前计算好对应sequence_output特征,保存在数据库中。当用户搜索query时,只需计算query的sequence_output特征与保存在数据库中的title sequence_output特征,通过一个简单的mean_pooling和全连接层进行二分类即可。从而大幅提升预测效率,同时也保障了模型性能。
在这里插入图片描述
实际场景中,只要改变下模型的输出,就可以提前把所以的待检索文本向量话,然后进行存储。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/464434
推荐阅读
相关标签
  

闽ICP备14008679号