欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 编程资源 > 编程问答 >内容正文

编程问答

【Transformer】ViT:An image is worth 16x16: transformers for image recognition at scale

发布时间:2023/12/15 编程问答 35 豆豆
生活随笔 收集整理的这篇文章主要介绍了 【Transformer】ViT:An image is worth 16x16: transformers for image recognition at scale 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

文章目录

    • 一、背景和动机
    • 二、方法
    • 三、效果
    • 四、Vision Transformer 学习到图像的哪些特征了
    • 五、代码

代码链接:https://github.com/lucidrains/vit-pytorch

论文连接:https://openreview.net/pdf?id=YicbFdNTTy

一、背景和动机

Transformer 在 NLP 领域取得了很好的效果,但在计算机视觉领域还没有很多应用,所以作者想要借鉴其在 NLP 中的方法,在计算机视觉的分类任务中进行使用。

二、方法


由于 Transformer 在 NLP 中使用时,都是接受一维的输入,而图像是二维的结构,所以需要先把图像切分成大小相等的patch,然后再编码成一个序列,送入 Transformer

Vision Transformer 的过程:

1、输入图像分割成 patch,并使用可学习的线性变换将其拉伸成 D 维,为输入 Transformer 做准备

  • 首先,输入图像为 x∈RH×W×Cx\in R^{H \times W \times C}xRH×W×C,将其 reshape 为 xp∈RN×(P2C)x_p\in R{N \times (P^2 C)}xpRN×(P2C),切分后的 patch 大小为 P×PP \times PP×PN=HW/P2N=HW/P^2N=HW/P2 为 patch 的个数。这里,NNN 也掌握着 Transformer 的效率。

  • Transformer 在每层都使用维度大小为 DDD 的向量输入,所以在送入 Transformer 之前,会使用可学习的线性影射(公式1)来将拉平的 patch 信息转换为 D 维。

  • 输入:3x224x224

  • patch 大小 PPP:32x32

  • patch 个数 NNN:7x7=49

  • D:128

2、给 D 维的 Transformer 输入后面连接一个 class token

将输入编码成 B×N×DB\times N \times DB×N×D 输入 Transformer 之前,会给 NNN 这个维度增加一个 class token,变成 N+1N+1N+1 维,这个 class token 是一个大小为 1×1×D1\times 1 \times D1×1×D 的可学习向量,表示 Transformer 的 Encoder 的输出(zL0z_L^0zL0),也就是作为图像的特征表达 y。

3、给上面的结果加上位置编码

添加位置编码能够保留图像的位置信息,作者使用可学习的 1D 位置编码(因为从文献来看,使用 2D 编码也没有带来理想的效果)。ViT 中的位置编码是随机生成可学习参数,没有做过多设计,这样的设计。

4、送入 Transformer Encoder

这里的 Encoder 由多层的 multiheaded self-attention(MSA)和 MLP 组成,每层之前都会使用 Layernorm,每个 block 之后都会使用残差连接。

归一化方法:

Transformer 中一般都使用 LayerNorm,LayerNorm 和 BatchNorm 的区别如下图所示:

  • LayerNorm:对一个 batch 的所有通道进行归一化(均值为 0,方差为 1)
  • BatchNorm:对一个通道的所有 batch 进行归一化(均值为 0,方差为 1)

三、效果

  • 使用中等大小的数据集(如 ImageNet),Transformer 比 ResNet 的效果稍微差点,作者认为原因在于 Transformer 缺少了 CNN 中的归纳偏置(平移不变性和位置),泛化的也较差
  • 使用大型数据集训练时(约 14M-300M images),作者发现大型的数据训练会胜过归纳偏置带来的效果,ViT 在使用了大型数据集预训练(ImageNet-21k 或 in-house JFT-300M)然后迁移到其他任务时,效果优于 CNN。

归纳偏置是什么?

归纳偏置可以理解成在算法在设计之初就加入的一种人为偏好,将某种方式的解优于其他解,既包含低层数据分布假设,也包含模型设计。

在深度学习时代,这种归纳性偏好更为明显。比如深度神经网络结构就偏好性的认为,层次化处理信息有更好效果;卷积神经网络认为信息具有空间局部性(locality),可以用滑动卷积共享权重方式降低参数空间;反馈神经网络则将时序信息考虑进来强调顺序重要性;图网络则是认为中心节点与邻居节点的相似性会更好引导信息流动。不同的网络结构创新就体现了不同的归纳性偏。

之前计算机视觉任务大都依赖于 CNN,CNN 有两个内置的归纳偏置:

  • 局部相关性
  • 权重共享

但基于注意力模型的 Transformer 最小化了归纳偏置,所以在大数据集上进行训练时,效果甚至可以超过 CNN,但小数据集上因为缺少了这种归纳偏置,所以难以总结到有意义的特征。

CNN 有较好的归纳偏置,所以数据少的时候也能实现好的效果,但数据量大的时候,这些归纳偏置就会限制其效果,但 Transformer 不会被其限制,所以在大数据集上表现更好一些。


四、Vision Transformer 学习到图像的哪些特征了

为了理解 Transformer 是如何学习到图像特征的,作者分析了其内部的特征表达:

  • Transformer 的第一层将 flattened patch 线性影射到了一个低维空间(公式 1),图 7 左侧可视化了前几个主要的学习到的 embedding filters,这些组件类似于每个patch内精细结构的低维表示的可信基函数。
  • 线性投影之后,加上位置编码,图 7 中间展示了模型学习了在位置嵌入相似度下对图像内距离进行编码,即离得近的 patches 更趋向于有相似的位置嵌入,然后就有了 row-column 结构,同一行或同一列的 patches 有相似的嵌入。
  • 自注意力机制能够提取整幅图像的信息,作者为了探究这种注意力给网络起了多大的作用,根据注意力的权重计算了其在空间中的平均距离(图 7 右),这种”注意力距离”类似于 CNN 中的感受野的大小。作者注意到,一些 heads 趋向于关注最底层的大部分图像,这表明模型确实使用了全局整合信息的能力。其他注意头在较低层上的注意距离一直很小。这种高度位置化的注意在混合模型中不太明显,这些模型在Transformer之前应用了ResNet(图7,右),这表明它可能具有与cnn中的低层卷积层类似的功能。注意距离随网络深度的增加而增加。从全局来看,发现模型关注与分类语义相关的图像区域(图6)。



上面中间的热力图可视化,某个位置和自己的余弦相似度肯定是最高的,然后和同行同列相似度次高,其他位置较低,这也能基本想通,因为位置本来就表示的某个像素在图像中的某行某列,符合可视化结果。

五、代码

总体过程:

  • 输入原图:[1, 3, 224, 224]
  • patch 编码:[1, 49, 1024]
  • cls_token:[1, 50, 1024]
  • 位置编码:[1, 50, 1024]
  • Transformer:Attention + FeedForward [1, 50, 1024]
  • 取第一组向量(或均值)作为全局特征:[1, 1024]
  • MLP 输出预测类别:[1, 1000],1000为类别数
import torch from torch import nnfrom einops import rearrange, repeat from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head * headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):import pdb; pdb.set_trace()qkv = self.to_qkv(x).chunk(3, dim = -1) # len=3, qkv[0].shape=qkv[1].shape=qkv[2].shape=[1, 50, 1024]q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) # q.shape=k.shape=v.shape=[1, 16, 50, 64]dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # [1, 16, 50, 50]attn = self.attend(dots) # [1, 16, 50, 50]out = torch.matmul(attn, v) # [1, 16, 50, 64]out = rearrange(out, 'b h n d -> b n (h d)') # [1, 50, 1024]return self.to_out(out) # [1, 50, 1024]class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))]))def forward(self, x):for attn, ff in self.layers:# attn: attention# ff: feedforwardx = attn(x) + xx = ff(x) + xreturn xclass ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()import pdb;pdb.set_trace()image_height, image_width = pair(image_size) # 224, 224patch_height, patch_width = pair(patch_size) # 32, 32assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width) # 7*7=49patch_dim = channels * patch_height * patch_width # 3072assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.Linear(patch_dim, dim), # Linear(in_features=3072, out_features=1024, bias=True))self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # [1, 50, 1024]self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) # [1, 1, 1024]self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes))def forward(self, img):# img: [1, 3, 224, 224]x = self.to_patch_embedding(img) # [1, 49, 1024]b, n, _ = x.shapecls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) # [1, 1, 1024] ->[b, 1, 1024]x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embedding[:, :(n + 1)] # [1, 50, 1024]x = self.dropout(x) # [1, 50, 1024]x = self.transformer(x)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] # [1, 1024]x = self.to_latent(x)return self.mlp_head(x)

调用 ViT 的方法:

pip install vit-pytorch import torch from vit_pytorch import ViTv = ViT(image_size = 256,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1 )img = torch.randn(1, 3, 256, 256)preds = v(img) # (1, 1000)

Patch embedding 结构:

Sequential((0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)(1): Linear(in_features=3072, out_features=1024, bias=True) )

Transformer 结构:

Transformer((layers): ModuleList((0): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(1): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(2): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(3): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(4): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(5): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))) )

前传方式:

def forward(self, x):for attn, ff in self.layers:# attn: attention# ff: feedforwardx = attn(x) + xx = ff(x) + xreturn x

MLP 结构:

Sequential((0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(1): Linear(in_features=1024, out_features=1000, bias=True) )

总结

以上是生活随笔为你收集整理的【Transformer】ViT:An image is worth 16x16: transformers for image recognition at scale的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。