赞
踩
从基础的nn.Embedding说起:
CLASS torch.nn.Embedding(num_embeddings, embedding_dim,
padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None)
num_embeddings, embedding_dim没啥好说的,就是look-up表的形状,我们在搭建网络时大多情况下只用得上这两个参数。下面具体看看剩下的参数能做什么:
好了,明白了Embedding的参数,再来看EmbeddingBag:
CLASS torch.nn.EmbeddingBag(num_embeddings, embedding_dim,
max_norm=None, norm_type=2.0, scale_grad_by_freq=False, mode=‘mean’, sparse=False, _weight=None)
官方API: https://pytorch.org/docs/stable/nn.htmlhighlight=embeddingbag#torch.nn.EmbeddingBag
参数只多了一个:mode,先来看这个参数的含义。
官网上说得很清楚,取值分三种,对应三种操作:"sum"表示普通embedding后接torch.sum(dim=0),"mean"相当于后接torch.mean(dim=0),"max"相当于后接torch.max(dim=0)。
只看这个参数就清楚了,EmbeddingBag就是把look-up表整合成一个embedding,当不需要具体查表获得embedding,只需要一个整合结果时,它比上述两阶段操作更高效。
来看它的输入:
input (LongTensor)和offsets (LongTensor, optional)
input可以是2D或1D:
说到这可以发现其实和类名一样,这就是个“词袋”操作,典型的应用场景是FastText,多个文档平铺成1D输入,再指定offsets,直接就可以进行批量不等长文档处理,写起来简单,效率又有提升。
官方的例子:
per_sample_weights(Tensor, optional)
该输入给每个实例一个权重再加权求和(此时mode只能为sum),与输入shape相同。
一个典型的应用场景是deepFM,某列特征对应的embedding有时需按照权重加和。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。