导读
自从Transformer开始用于视觉任务后,相比于CNN来说虽然是另辟蹊径,但一直以来Transformer在视觉领域的应用比起NLP来总感觉差那么点意思。综合分析下来主要有两个问题:1.NLP领域的scale相对固定,而视觉领域的scale变化范围非常大。2.视觉应用需要大的图像分辨率,视觉Transformer计算的复杂度是图像分辨率的平方,因而计算量会变得异常庞大。针对ViT模型存在的上述问题,微软亚洲研究院提出了Swin Transformer这样一种新的视觉Transformer结构。Swin Transformer相较于ViT模型,做了两个相当大的创新改进:1.通过类似于CNN的层次化 (hierarchical) 方式来构建Transformer层。2.引入locality和windows设计,按windows计算self-attention。
提出Swin Transformer的这篇论文为Swin Transformer: Hierarchical Vision Transformer using Shifted Windows,由微软亚洲研究院提出,发表于2021年3月份。作为能够取代ViT,成为视觉Transformer领域新的backbone的Swin Transformer,本文力求对其进行一个相对完整和详细的解读。源码地址:https://github.com/microsoft/Swin-Transformer
Swin Transformer整体结构
Swin Transformer模型整体结构如下图所示:
Swin Transformer
可以看到,Swin Transformer由四个类似的stage构成,每个stage里面都有若干个Swin Transformer block。假设输入图像维度为HxWx3,按照ViT的基本思路,先对其进行分块 (patch partition) 处理,每个patch的大小为4x4,flatten之后的patch特征维度为4x4x3=48,patch的数量为H/4 x W/4。在stage1阶段,先对图像patch做一个线性patch embedding,然后是两个连续的Swin Transformer block,特征维度由48变为C,输出特征图大小为H/4 x W/4 x C。
然后是stage2阶段,根据Swin Transformer层次化设计的思想,stage2中先用了一个patch mergeing操作,将输入特征图按照2x2的相邻的patches进行合并,再经过两个连续的Swin Transformer block,输出维度就变成了H/8 x W/8 x 4C,然后用一个线性embedding将4C压缩为2C,所以stage2最后的输出维度为H/8 x W/8 x 2C。同理stage3和stage4继续这种patch embedding和Swin Transformer block的操作,只不过使用的Swin Transformer block数量有所区别,stage3用了6个block,stage4用了2个。最后的输出为H/32 x W/32 x 8C。下图为Swin Transformer的层次化设计:
Swin Transformer层次化设计
相较于ViT随着网络加深patch数量保持不变,Swin Transformer在网络加深过程中,patch数量会随着层次化的网络搭建特征而逐渐减少。
Swin Transformer block
通过Swin Transformer的整体架构我们可知,Swin Transformer block是整个架构的核心结构,两个连续的Swin Transformer block如下图所示:
如上图所示,第一个Swin Transformer block,由LayerNorm、W-MSA和MLP组成,W-MSA即Windows MSA,可以理解为分窗的多头注意力。第二个Swin Transformer block跟第一个block基本一样,只是将W-MSA换成了SW-MSA,即Shifted Windows,可以理解为移位的窗口多头注意力。Shifted Windows也正是Swin Transformer的命名由来。
基于自注意力的Shifted Window
由上图可以看到Swin Transformer block的关键在于W-MSA和SW-MSA。我们先梳理一下其中的变换逻辑。首先是为什么要做分窗 (Window) 处理?由前面描述可知分窗一方面是为了构建层次化的Transformer层设计,另一方面则是为了减少计算量,Window based attention要比Image based atttention节约很多计算资源。这是常规的分窗的好处?那么这种分窗操作是否有问题?当然有。一个最主要的问题就是不同窗口之间缺乏联系,这会对Transformer的建模能力产生较大影响。
为了让不重叠的Window之间产生联系,Swin Transformer提出了一种基于移位窗口的Shifted Window的设计,这也是Swin Transformer最核心最关键的设计。Shifted Window如下图所示:
Shifted Window
从上图可以看到,常规的分窗大小为2x2,每个窗口大小为4x4个patches,但是这种分窗会使得不同Windows之间没有联系,所以要采用右边的Shifted Window的分窗操作。这种Shifted Window的分窗设计能够使原先不相交的Window之间产生联系,通过一个W-MSA的Swin Transformer block和一个SW-MSA的Swin Transformer block的组合,就能解决前述问题。W-MSA和SW-MSA的前向过程如下公式所示:
Shifted Window虽好,但又产生了其他问题。windows数量由2x2变为了3x3,并且还有一些小的window,比如2x2和4x2大小的window。针对这个问题,作者们又提出一种高效的批量计算方法 (efficient batch computation approach) 来进行处理。具体地,就是先通过一种特征图移位 (cyclic shift) 的方式来对窗口进行移动,移动后window由多个不相邻的sub-window组成,这里,作者又提出了一种attention mask的机制来实现前述设计,这使得能够在保持window个数不变的情况下,最后的计算结果是等价的。efficient batch computation approach如下图所示:
Cyclic Shift
这里需要专门把Cyclic Shift和Attention Mask挑出来说明一下。Cyclic Shift是一种简单的矩阵移位方法,具体代码实现时可以用PyTorch的torch.roll方法来实现,以3x3矩阵为例,如下图所示:
其中shifts参数为移动的步长,取负数表示逆向移动。dims表示要进行位移的维度。
Attention Mask
如果说Shifted Window是Swin Transformer的精华,那么Attention Mask则可以算作是Shifted Window的精华。Attention Mask主要干的事就是设置合理的mask,使得Shifted Window Attention在与Window Attention相同的窗口个数下,得到等价的计算结果。如下图所示,分别给SWA和WA加上index后,再计算window attention的时候,希望有相同index的qk进行计算,而忽略不同index的qk计算结果。
其中index 5本身不需要做变化,但剩下的(6,4)、(2,8)和(1,3,7,9)都混在一起了。所以就分别要对这三个设计一个mask。以(6,4)为例,最后设计的mask如下图所示。
小结
Swin Transformer的故事线大概是这样:因为图像视觉存在大尺度和ViT计算量大的问题,所以要提出Windows操作,把注意力计算限制在窗口内。但是分窗会造成不同窗口之间没有联系,所以又提出Shifted Window操作,使得不同窗口之间可以联系起来。但是Shifted Window又会使得窗口数变多,所以又提出efficient batch computation方法。efficient batch computation方法中的cyclic shift和attention mask是实现高效计算的核心操作。