当前位置:   article > 正文

【自然语言处理】【对比学习】搞nlp还不懂对比学习,不会吧?快来了解下SimCLR_对比学习 nlp

对比学习 nlp

近期对比学习 NLP \text{NLP} NLP领域取得了不错的成绩,例如句嵌入方法 SimCSE [ 1 ] \text{SimCSE}^{[1]} SimCSE[1]和短文本聚类方法 SCCL [ 2 ] \text{SCCL}^{[2]} SCCL[2]。为了能更好的理解近期的进展,期望通过一系列相关的文章来循序渐进的介绍其中的技术和概念。本文就作为该系列的第一篇文章吧~

一、 SimCLR \text{SimCLR} SimCLR简介

二、 SimCLR \text{SimCLR} SimCLR框架

在这里插入图片描述

图1. SimCLR框架

SimCLR \text{SimCLR} SimCLR是一个对比学习的框架,其结构如图1所示,主要包含四个组件:

1. 数据增强模块

该模块会为一个样本随机生成两个增强样本 x ~ i \tilde{x}_i x~i x ~ j \tilde{x}_j x~j,这两个样本组成了一个正样本对 ( x ~ i , x ~ j ) (\tilde{x}_i,\tilde{x}_j) (x~i,x~j)

  • 论文主要是针对图像的。因此,采用的数据增强方式包括:裁剪、颜色失真、高斯模糊;
2. 编码器

编码器 f ( ⋅ ) f(\cdot) f()的作用是将增强样本转换为向量表示, h i = f ( x ~ i ) \textbf{h}_i=f(\tilde{x}_i) hi=f(x~i)

  • 论文选择 ResNet \text{ResNet} ResNet作为编码器, h i = f ( x ~ i ) = ResNet ( x ~ i ) \textbf{h}_i=f(\tilde{x}_i)=\text{ResNet}(\tilde{x}_i) hi=f(x~i)=ResNet(x~i)
3. 投影头(Projection head)

投影头 g ( ⋅ ) g(\cdot) g()是一个小型神经网络,其作用是将样本的向量表示映射至可以对比的空间中(也就是适合Loss计算的表示空间);

  • 论文使用单层全连接神经网络作为投影头,即 z i = g ( h i ) = W ( 2 ) σ ( W ( 1 ) h i ) z_i=g(\textbf{h}_i)=W^{(2)}\sigma(W^{(1)}\textbf{h}_i) zi=g(hi)=W(2)σ(W(1)hi) σ \sigma σ ReLU \text{ReLU} ReLU激活函数;
4. 对比损失函数

对比损失函数 l \mathcal{l} l,其作用是:在一个包含正样本对 ( x ~ i , x ~ j ) (\tilde{x}_i,\tilde{x}_j) (x~i,x~j)的集合 { x ~ k } \{\tilde{x}_k\} {x~k},给定样本 x ~ i \tilde{x}_i x~i,从 { x ~ k } k ≠ i \{\tilde{x}_k\}_{k\neq i} {x~k}k=i中确定出 x ~ j \tilde{x}_j x~j

三、框架的实现

上面描述了 SimCLR \text{SimCLR} SimCLR框架,本小节则是该框架的一个具体实现。

1. 损失函数 NT-Xent \text{NT-Xent} NT-Xent
  • 随机采样 N N N个样本作为 minibatch \text{minibatch} minibatch,并通过数据增强生成 2 N 2N 2N个样本。这里将正样本对以外 2 ( N − 1 ) 2(N-1) 2(N1)个样本当做负样本;
  • 向量相似度计算方式为: sim ( u , v ) = u ⊤ v / ∥ u ∥ ∥ v ∥ \text{sim}(u,v)=u^\top v/\Vert u\Vert\Vert v\Vert sim(u,v)=uv/uv
  • 正样本对 ( i , j ) (i,j) (i,j)的损失函数

l i , j = − log exp(sim( z i , z j ) / τ ) ∑ k = 1 2 N 1 k ≠ i exp(sim( z i , z k ) / τ ) \mathcal{l}_{i,j} = -\text{log}\frac{\text{exp(sim(}z_i,z_j)/\tau)}{\sum_{k=1}^{2N}1_{k\neq i}\text{exp(sim(}z_i,z_k)/\tau)} li,j=logk=12N1k=iexp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)
​ 其中, 1 [ k ≠ i ] ∈ { 0 , 1 } 1_{[k\neq i]}\in\{0,1\} 1[k=i]{0,1}是指示函数, τ \tau τ是温度(temperature)参数;

  • 同一个 minibatch \text{minibatch} minibatch中所有正样本对的损失之和为最终的loss,称这个loss为 NT-Xent \text{NT-Xent} NT-Xent
2. 完整的算法描述

输入:batch size N N N,常量 τ \tau τ,结构 f , g , T f,g,\mathcal{T} f,g,T

for 采样的minibatch { x k } k = 1 N \{x_k\}_{k=1}^N {xk}k=1N do

​  for all k ∈ { 1 , … , N } k\in\{1,\dots,N\} k{1,,N} do

​ 随机选择两种数据增强函数 t ∼ T , t ′ ∼ T t\sim\mathcal{T},t'\sim\mathcal{T} tT,tT

​​ # 第一个数据增强

​​  x ~ 2 k − 1 = t ( x k ) \tilde{x}_{2k-1}=t(x_k) x~2k1=t(xk)

​​  h 2 k − 1 = f ( x ~ 2 k − 1 ) h_{2k-1}=f(\tilde{x}_{2k-1}) h2k1=f(x~2k1) # 表示

​​  z 2 k − 1 = g ( h 2 k − 1 ) z_{2k-1}=g(h_{2k-1}) z2k1=g(h2k1) # 投影

​​ # 第二个数据增强

​​  x ~ 2 k = t ′ ( x k ) \tilde{x}_{2k}=t'(x_k) x~2k=t(xk)

​​  h 2 k = f ( x ~ 2 k − 1 ) h_{2k}=f(\tilde{x}_{2k-1}) h2k=f(x~2k1) # 表示

​ ​  z 2 k = g ( h 2 k − 1 ) z_{2k}=g(h_{2k-1}) z2k=g(h2k1) # 投影

​  end for

  for all i ∈ { 1 , … , 2 N }  and  j ∈ { 1 , … , 2 N } i\in\{1,\dots,2N\}\text{ and } j\in\{1,\dots,2N\} i{1,,2N} and j{1,,2N} do

​   s i , j = z i z j / ( ∣ ∣ z i ∣ ∣ ∣ ∣ z j ∣ ∣ ) s_{i,j}=z_iz_j/(||z_i||||z_j||) si,j=zizj/(zizj)

​  end for

​  定义 l ( i , j ) \mathcal{l}(i,j) l(i,j) l i , j = − log exp(sim( z i , z j ) / τ ) ∑ k = 1 2 N 1 k ≠ i exp(sim( z i , z k ) / τ ) \mathcal{l}_{i,j} = -\text{log}\frac{\text{exp(sim(}z_i,z_j)/\tau)}{\sum_{k=1}^{2N}1_{k\neq i}\text{exp(sim(}z_i,z_k)/\tau)} li,j=logk=12N1k=iexp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)

​   L = 1 2 N ∑ k = 1 N [ l ( 2 k − 1 , 2 k ) + l ( 2 k , 2 k − 1 ) ] \mathcal{L}=\frac{1}{2N}\sum_{k=1}^N[\mathcal{l}(2k-1,2k)+\mathcal{l}(2k,2k-1)] L=2N1k=1N[l(2k1,2k)+l(2k,2k1)]

​  通过最小化 L \mathcal{L} L来更新网络 f f f g g g

end for

return 返回编码网络 f ( ⋅ ) f(\cdot) f(),并丢弃 g ( ⋅ ) g(\cdot) g()

3. 训练细节
  • 为了不使用memory bank,将batch size从256增大至8192;
  • 由于 SGD \text{SGD} SGD在大batch size上不稳定,使用LARS进行训练;

四、分析

  • 数据增强操作的组合对于学习好的向量表示至关重要
    在这里插入图片描述

    上图是不同种数据增强方式间组合带来的影响,对角线表示单个一种数据增强方法。可以发现,对角线的颜色都比较深,也就是说单一的数据增强方式效果并不好。两两组合的数据增强方式效果更佳。

  • 相较于有监督学习,数据增强对对比学习更加有效
    在这里插入图片描述

    上表时数据增强程度对有监督学习(Supervised)和对比学习(SimCLR)的影响。可以发现,数据增强对“对比学习”影响更大。

  • 模型越大、对比学习效果越好

在这里插入图片描述

上图中红色的点是对比学习的效果,随着模型规模的增大,效果也越来越好;

  • 非线性投影头能改善向量表示的质量

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vioc8xlZ-1622874702331)(\图片\投影头.png)]

    上图中,非线性投影头优于线性投影头,线性投影头优于不进行投影;

  • 合适的温度参数能够帮助模型学习到更难的负样本

在这里插入图片描述

观察上表, l2 norm \text{l2 norm} l2 norm是有效的,而是适当大小的 τ \tau τ也有助于模型的表现;

  • 大batch size和长的训练时间也有益于对比学习
    在这里插入图片描述

    观察上图,大的batch size和较大的epoch有助于模型的表现;

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/595923
推荐阅读
相关标签
  

闽ICP备14008679号