Press "Enter" to skip to content

DaViT:双注意力Vision Transformer

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

码字不易,欢迎点赞!

 

DaViT 是香港大学和微软等研究机构在近期新发布的一个Vision Transformer模型,这个工作的创新之处是提出了一种 双注意力机制(dual attention) 来高效地实现全局建模,其中最大的模型DaViT-Giant在ImageNet1K数据集上达到了 90.4%的Top1 Accuracy ,超过了之前的SwinV2(90.17%)。

 

这里的双attention是从两个正交的角度来进行self-attention:一是对 spatial tokens 进行self-attention,此时空间维度( )定义了tokens的数量,而channel维度( )定义了tokens的特征大小,这其实也是ViT最常采用的方式;二是对 channel tokens 进行self-attention,这和前面的处理完全相反,此时channel维度( )定义了tokens的数量,而空间维度( )定义了tokens的特征大小。可以看出两种self-attention完全是相反的思路。为了减少计算量,两种self-attention均采用分组的attention:对于 spatial tokens 而言,就是在空间维度上划分成不同的windows,这就是Swin中所提出的window attention,论文称之为 spatial window attention ;而对于 channel tokens ,同样地可以在channel维度上划分成不同的groups,论文称之为 channel group attention 。这两种attention如下图所示:

 

两种attention能够实现互补: spatial window attention 能够提取windows内的局部特征,而 channel group attention 能学习到全局特征,这是因为每个channel token在图像空间上都是全局的。 DaViT和Swin一样采用金字塔结构 ,共包含4个stage,每个stage包含一定数量的 dual attention block ,这个block就是将两种attention(还包含FFN) 交替地堆叠在一起 ,如下所示:

 

下面我们来介绍两种attention的具体实现细节。在开始之前,先回顾一下global self-attention。假定共有 个patchs(特征图大小 ),每个patch的特征大小为 (channel数量),所有的patchs的特征 ,那幺self-attention的计算如下所示:

 

 

这里self-attention的head数量为 ,第 个head的query,key和value通过线性投射得到: ,它们的维度大小为 ,其中 。这里省略attention之后的线性投射,所以共有4个线性投射,总的计算复杂度为 ,而attention的计算复杂度为 ,所以最终总的计算复杂度为 。可以看到计算量与patchs总量的平方成正比,而patchs总量和图像大小线性相关,当图像大小增加时,global self-attention的计算量将大幅度增加。

 

采用 spatial window attention 可以减少上述计算量,这里将patchs按照空间结构划分成 个window(比如7×7大小),每个window的patchs记为 ,这里 。然后每个window里面的patchs单独进行self-attention:

 

 

此时每个window attention的attention部分的计算复杂度为 ,总的计算复杂度为 。如果固定window大小,那幺window attention的计算量就和patchs总量成线性关系。虽然window attention降低了计算量,但是也变成了一种local attention,因为不同的windows之间并没有信息交换,Swin通过复杂的shift window来实现这种信息交换。

 

而 channel group attention 可以实现全局的attention。首先将channel分成 个group,每个group的channel数量为 ,这里有 ,那幺 channel group attention 的计算如下所示:

 

 

这里的 分别是query,key和value,注意这里还是按照channel维度来进行线性投射得到,而不是在spatial维度,因为这样权重 是和图像的大小是无关的,模型可以适应任何大小的图像作为输入。在attention计算时,只需要将query,key和value的维度进行反转,即维度变成 ,就可以实现channel attention了,由于这里我们希望attention在空间维度是全局的,所以不采用multi-head attention,或者说只用一个head,另外注意这里的scale因子采用的是 ,而不是 ,因为后者是图像大小相关的。同样地,channel group attention也包含4个线性投射,其计算复杂度也是 ,而attention部分的计算复杂度为 。channel group attention的主要实现代码如下所示:

 

class ChannelAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads # 这里的num_heads实际上是num_groups
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
    def forward(self, x):
        B, N, C = x.shape
        # 得到query,key和value,是在channel维度上进行线性投射
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        k = k * self.scale
        attention = k.transpose(-1, -2) @ v # 对维度进行反转
        attention = attention.softmax(dim=-1)
        x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

 

这样对于dual attention,其总的attention计算量为 ,而线性投射的计算量为 ,当FFN的expand ratio为4时,FFN的计算量为 。

 

此外,dual attention 采用depth-wise conv来实现位置编码 ,以window attention为例,分别在self-attention和FFN之前均插入一个3×3 depth-wise conv,代码如下所示(channel attention也是如此):

 

class SpatialBlock(nn.Module):
    r""" Windows Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, dim, num_heads, window_size=7,
                 mlp_ratio=4., qkv_bias=True, drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 ffn=True, cpe_act=False):
        super().__init__()
        self.dim = dim
        self.ffn = ffn
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        # conv位置编码
        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.ffn:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer)
    def forward(self, x, size):
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = self.cpe[0](x, size) # depth-wise conv
        x = self.norm1(shortcut)
        x = x.view(B, H, W, C)
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)
        # merge windows
        attn_windows = attn_windows.view(-1,
                                         self.window_size,
                                         self.window_size,
                                         C)
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)
        x = self.cpe[1](x, size) # 第2个depth-wise conv
        if self.ffn:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, size

 

DaViT也采用金字塔结构:首先是一个patch embedding层,采用stride=4的7×7 conv,然后是4个stages,各个stages通过stride=2的2×2 conv来进行降采样。其中DaViT-Tiny,DaViT-Small和DaViT-Base三个模型的配置如下所示:

 

注意,这里window attention的num_heads和channel attention的num_groups采用相同的配置。

 

除了上面3个基础模型,论文还增加了另外3个更大的模型:

DaViT-Large: C = 192, L = {1, 1, 9, 1}, Ng = Nh = {6, 12, 24, 48}
DaViT-Huge: C = 256, L = {1, 1, 9, 1}, Ng = Nh = {8, 16, 32, 64}
DaViT-Giant: C = 384, L = {1, 1, 12, 3}, Ng = Nh = {12, 24, 48, 96}

其中DaViT-Giant参数量达到了4B。

 

下表为DaViT和其它模型在ImageNet1K上的分类结果对比,可以看到DaViT在同样的参数量和FLOPs超过其它模型如Swin,Focal和PVTv2。其中最大的模型DaViT-Giant在1.5B的图像-文本对数据集上预训练后,可以在ImageNet1K数据集上达到90.4%。

 

在下游检测任务上,DaViT也超过其它模型,DaViT的一个明显优势当模型增大时,性能是稳步提升的,但是其它模型如Swin和Focal,Base模型效果反而比Small模型要差一些。

 

同样地,在语义分割数据集ADE20K上,DaViT也表现了较好的性能:

 

总体来看,DaViT是一个非常好的工作,相比其它local attention的模型,DaViT从设计上保证了全局建模。DaViT的双注意力机制和谷歌的MLP-Mixer有点类似,MLP-Mixer中包括两个不同的MLP:channel-mixing MLP和token-mixing MLP,其中token-mixing MLP也是通过反转维度来实现全局信息交互的,不过MLP的缺点是权重矩阵依赖输入大小,而attention没有这个问题。

 

DaViT的detectron2版本见: GitHub – xiaohu2015/nndet2

 

参考

https:// github.com/dingmyu/davi t
https:// arxiv.org/abs/2204.0364 5

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。