赞
踩
3层RoBERTa效果(RBT3)
class SentenceTransformer(nn.Layer): def __init__(self, pretrained_model, dropout=None): super().__init__() self.ptm = pretrained_model self.dropout = nn.Dropout(dropout if dropout is not None else 0.1) # num_labels = 2 (similar or dissimilar) self.classifier = nn.Linear(self.ptm.config["hidden_size"] * 3, 2) def forward(self, query_input_ids, title_input_ids, query_token_type_ids=None, query_position_ids=None, query_attention_mask=None, title_token_type_ids=None, title_position_ids=None, title_attention_mask=None): query_token_embedding, _ = self.ptm( query_input_ids, query_token_type_ids, query_position_ids, query_attention_mask) query_token_embedding = self.dropout(query_token_embedding) query_attention_mask = paddle.unsqueeze( (query_input_ids != self.ptm.pad_token_id ).astype(self.ptm.pooler.dense.weight.dtype), axis=2) # Set token embeddings to 0 for padding tokens query_token_embedding = query_token_embedding * query_attention_mask query_sum_embedding = paddle.sum(query_token_embedding, axis=1) query_sum_mask = paddle.sum(query_attention_mask, axis=1) query_mean = query_sum_embedding / query_sum_mask title_token_embedding, _ = self.ptm( title_input_ids, title_token_type_ids, title_position_ids, title_attention_mask) title_token_embedding = self.dropout(title_token_embedding) title_attention_mask = paddle.unsqueeze( (title_input_ids != self.ptm.pad_token_id ).astype(self.ptm.pooler.dense.weight.dtype), axis=2) # Set token embeddings to 0 for padding tokens title_token_embedding = title_token_embedding * title_attention_mask title_sum_embedding = paddle.sum(title_token_embedding, axis=1) title_sum_mask = paddle.sum(title_attention_mask, axis=1) title_mean = title_sum_embedding / title_sum_mask sub = paddle.abs(paddle.subtract(query_mean, title_mean)) projection = paddle.concat([query_mean, title_mean, sub], axis=-1) logits = self.classifier(projection) probs = F.softmax(logits) return probs
Sentence Transformer采用了双塔(Siamese)的网络结构。Query和Title分别输入ERNIE,共享一个ERNIE参数,得到各自的token embedding特征。之后对token embedding进行pooling(此处教程使用mean pooling操作),之后输出分别记作u,v。之后将三个表征(u,v,|u-v|)拼接起来,进行二分类。网络结构如上图所示。
那么Sentence Transformer采用Siamese的网路结构,是如何提升预测速度呢?
Siamese的网络结构好处在于query和title分别输入同一套网络。如在信息搜索任务中,此时就可以将数据库中的title文本提前计算好对应sequence_output特征,保存在数据库中。当用户搜索query时,只需计算query的sequence_output特征与保存在数据库中的title sequence_output特征,通过一个简单的mean_pooling和全连接层进行二分类即可。从而大幅提升预测效率,同时也保障了模型性能。
实际场景中,只要改变下模型的输出,就可以提前把所以的待检索文本向量话,然后进行存储。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。