使用无监督学习获取图像表征向量是一个长期问题,目前已经有大量的自监督学习算法用于获取图像表征向量,但是这些方法和有监督的方式仍然相距甚远。本文介绍一种采用对比学习的框架 SimCLR,可以生成高质量的图像表征,可以取得接近有监督学习算法的效果。

概述

在前一篇文章SimCSE: 通过对比学习获得句子向量中我们介绍了 SimCSE 算法,SimCSE 采用对比学习训练得到句子向量。本文介绍一种利用对比学习生成图像表征的算法 SimCLR,SimCLR 出自 Google 的论文《A Simple Framework for Contrastive Learning of Visual Representations》。

对比学习 (Contrastive Learning) 的目标就是让模型学会区分样本是否相似,因此训练需要同时提供相似样本 (正样本) 和不相似样本 (负样本),如下图所示:

对比学习示意图

SimCLR 训练的数据无需人工标注,对于一幅图像 x,其采用数据增强的方式生成图片 x 的正样本对 (xi, xj),将 batch 里的其他图像当成负样本。然后 SimCLR 使用对比学习训练 Encoder (通常是 CNN 模型,例如 ResNet),从而生成高质量的图像表征。在实验中 SimCLR 取得了 SOTA 的效果,超越了之前的自监督学习算法,并且 top-1 准确率可以逼近有监督的 ResNet-50。

2.SimCLR

SimCLR 结构图

SimCLR 的结构如上图所示,图片出自博客 The Illustrated SimCLR Framework,SimCLR 包含三个部分:

  • 数据增强 Data Augmentation,对图片进行随机的变换 (如裁剪、翻转、颜色抖动等),变换后的数据作为正样本。

  • Encoder,图像编码模型 (如 ResNet、AlexNet 等),SimCLR 使用 Encoder 获得图像表征向量,Encoder 也可用于其他下游任务的微调。

  • 非线性投影层,Projection Head,对 Encoder 输出的表征进行变换,投影层只用于训练 SimCLR,训练结束后使用 Encoder 得到图像表征。

2.1 数据增强

数据增强广泛用在视觉领域,能够增加样本的数量及多样性,使模型更加健壮。图像数据增强的方法多种多样,如下图所示:

数据增强

SimCLR 对图片进行数据增强时不是采用单一的增强方式,而是会随机使用多种不同的增强方法进行结合,这样能够产生更好的表征向量。

作者也通过一个小实验,证明结合不同的增强方法能够产生更好的表征向量。实验采用 ImageNet 数据集,指标为 top-1 准确率,实验结果如下图所示。其中对角线的位置表示采用单一的数据增强方法,其他位置表示两种数据增强方法结合,最后一列表示每一行的平均值。可以看到结合后的效果会大大提升。

不同数据增强方法组合的准确率

SimCLR 会为一个 batch 里的每一幅图像 x 进行两次数据增强,分别得到图像 xi 和 xj,则 (xi, xj) 作为一对正样本,如下图所示:

SimCLR 数据增强

经过数据增强后,我们就可以得到一个 batch 数据的正样本和负样本,如下图所示,SimCLR 需要让正样本的相似度尽可能高,让负样本之间的相似度尽可能低:

正负样本

2.2 非线性投影层

SimCLR 使用 ResNet-50 作为 Encoder,用于获取图像的表征向量 (Representation),同时 Encoder 也可用于后续的下游任务。但是 SimCLR 在训练时为了得到更好的效果,还需要在 Encoder 后增加非线性投影层 (Dense-Relu-Dense),如下图所示,注意非线性投影层只在训练时使用

SimCLR 结构图

作者在原文里对非线性投影层的作用进行了一些解释,认为 Encoder 后的表征 h 包含更多的信息 (例如数据增强变换信息、颜色、方向),而非线性投影层的输出 z 可以去掉这些多余的信息,还原数据本质。Encoder 的输出信息丰富对于下游任务更有帮助,但并不适合对比学习任务,因此用非线性投影层对数据进行还原从而更好地训练。

2.3 损失函数

假设图像 xi 和 xj 经过 SimCLR 的输出为 zi 和 zj,则首先要计算 zi 和 zj 的余弦相似度,如下。

余弦相似度

如果一个 batch 里有 N 个图像,则数据增强后会有 2N 个图像,每一个图像 xi 会有 1 个正样本和 2N-2 个负样本,则对于一对正样本 (zi 和 zj),损失函数如下所示。

损失函数

3.实验效果

下面的两幅图展示了 SimCLR 和其他自监督学习算法的对比,数据集为 ImageNet。可以看到 SimCLR 远超之前的算法,并且可以达到和有监督相近的准确率。

ImageNet 对比图
ImageNet 对比表

下图展示了 SimCLR 在图像分类上进行迁移学习的效果,用到了 12 个图像分类数据集。

SimCLR 迁移学习和有监督学习

4.参考文献

  • SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

    https://arxiv.org/pdf/2002.05709.pdf

  • SimCLR 代码

    https://github.com/google-research/simclr

  • 博客: The Illustrated SimCLR Framework

    https://amitness.com/2020/03/illustrated-simclr/

举报/反馈

NLP学习笔记

322获赞 802粉丝
专注分享自然语言处理知识
关注
0
0
收藏
分享