Info NCE

发布于 2022-12-05  789 次阅读


概述

其中是温度超参数。总数是一个正样本和个负样本的总和。直观地说,这个损失是基于 softmax的分类器的对数损失,它试图将分类为

设计目的是使当query 和唯一的正样本相似,并且和其他所有负样本key都不相似的时候,这个loss的值应该比较低。

从softmax说起

对于多分类问题,设类别标签可以有个取值。给定一个样本,Softmax回归预测的属于类别的条件概率为

其中是第类的权重向量。

在数学,尤其是概率论和相关领域中,归一化指数函数,或称Softmax函数,是逻辑函数的一种推广。它能将一个含任意实数的维向量“压缩”到另一个维实向量中,使得每一个元素的范围都在之间,并且所有元素的和为1,也就是说softmax保证了所有预测值在0-1之间且和为1。该函数多用于多分类问题中。

交叉熵

因为softmax本身不带有可学习的参数,因此在深度学习框架中通常和一些激活函数被同等的看待。模型结构上,多分类问题与二分类问题在预测头(最后一层全连接)的差异通常在其激活函数是使用softmax还是sigmoid。而在损失函数上,二分类问题常用二元交叉熵函数,多分类问题常用交叉熵函数。

交叉熵损失函数(Cross-Entropy Loss Function)一般用于分类问题.假设样本的标签为离散的类别,模型的输出为类别标签的条件概率分布,即

其中的即是最外层为softmax函数激活的模型本身输出的预测值的数学表示,为ground truth。

将softmax带入交叉熵,得

由于ground truth 中只会有一个分量为1,其余为0,假设实际分类为,则

可以看到已经和Info NCE的公式大体接近。

NCE

如果将对比学习看成多分类问题 正样本和所有的负样本都看成不同的类会导致有上百万的类,导致计算问题。NCE(noise contrastive estimation)核心思想是将多分类问题转化成二分类问题,一个类是数据类别 data sample,另一个类是噪声类别 noisy sample,通过学习数据样本和噪声样本之间的区别,将数据样本去和噪声样本做对比,也就是“噪声对比(noise contrastive)”,从而发现数据中的一些特性。但是,如果把整个数据集剩下的数据都当作负样本(即噪声样本),虽然解决了类别多的问题,计算复杂度还是没有降下来,解决办法就是做负样本采样来计算loss,这就是estimation的含义,也就是说它只是估计和近似。一般来说,负样本选取的越多,就越接近整个数据集,效果自然会更好。

Info NCE

Info NCE loss是NCE的一个简单变体,它认为如果你只把问题看作是一个二分类,只有数据样本和噪声样本的话,可能对模型学习不友好,因为很多噪声样本可能本就不是一个类,因此还是把它看成一个多分类问题比较合理(但这里的多分类k指代的是负采样之后负样本的数量,下面会解释)。于是就有了InfoNCE loss,公式如下:

上式中,是模型出来的logits(指某层的输出值),是一个温度超参数,是个标量,假设忽略,那么infoNCE loss其实就是cross entropy loss。唯一的区别是,在cross entropy loss里,指代的是数据集里类别的数量,而在对比学习InfoNCE loss里,这个指的是负样本的数量(在MoCo中即是负样本队列的长度,也就是字典长度)。上式分母中的sum是在1个正样本和k个负样本上做的,从0到,所以共个样本,也就是字典里所有的key。在MoCo中,InfoNCE loss其实就是一个cross entropy loss,实际实现也是用cross entropy loss,做的是一个类的分类任务,目的就是想把这个图片分到(正样本)这个类。

温度系数的作用

温度系数虽然只是一个超参数,但它的设置是非常讲究的,直接影响了模型的效果。 上式Info NCE loss中的相当于是logits,温度系数可以用来控制logits的分布形状。对于既定的logits分布的形状,当值变大,则就变小,则会使得原来logits分布里的数值都变小,且经过指数运算之后,就变得更小了,导致原来的logits分布变得更平滑。相反,如果取得值小,就变大,原来的logits分布里的数值就相应的变大,经过指数运算之后,就变得更大,使得这个分布变得更集中,更peak。

如果温度系数设的越大,logits分布变得越平滑,那么对比损失会对所有的负样本一视同仁,导致模型学习没有轻重。如果温度系数设的过小,则模型会越关注特别困难的负样本,但其实那些负样本很可能是潜在的正样本,这样会导致模型很难收敛或者泛化能力差。

总之,温度系数的作用就是它控制了模型对负样本的区分度。

联系本文

image-20221130192135261

运用在《Contrastive Code Representation Learning》中也没什么不同,仅仅只是正负样本的生成变化,以及从图像到代码,对比学习中数据预处理和具体编码器结构发生了变化。该论文使用使用Info NCE作为损失函数来优化,试图使经过数据增强后的代码与原代码的表征向量更为接近,而与其他不同语义的代码的表征向量更为疏远,使用对比学习为代码学习到了一个泛化性较好且带有语义信息的稳健分布式表示(特征向量)。

参考

[1] Momentum Contrast for Unsupervised Visual Representation Learning.

[2] https://www.bilibili.com/video/BV1C3411s7t9


面向ACG编程