《Attention Is All You Need》这篇论文彻底改变了自然语言处理的世界，基于 Transformer 的架构成为自然语言处理任务的标准。

```import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torchsummary import summary
from torchvision.transforms import Compose, Resize, ToTensor

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce```

# Load the demo image; later cells reuse this `img` handle as model input.
img = Image.open('penguin.jpg')

# Render the raw picture so the input can be inspected before patching.
fig = plt.figure()
plt.imshow(img)
plt.show()

# Preprocessing pipeline: resize to the ViT input resolution, then convert
# the PIL image to a float tensor in (C, H, W) layout.
transform = Compose([Resize((224, 224)), ToTensor()])

# Apply the pipeline and add a leading batch axis -> (1, 3, 224, 224).
x = transform(img).unsqueeze(0)
print(x.shape)

## 切分补丁和投影

# Naive patch extraction with einops: cut the image into 16x16 tiles and
# flatten each tile into one vector -> (batch, num_patches, patch_dim).
patch_size = 16
patches = rearrange(
    x,
    'b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
    s1=patch_size,
    s2=patch_size,
)

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly project each
    patch to an embedding vector.

    A Conv2d with kernel_size == stride == patch_size is equivalent to cutting
    the image into patch_size x patch_size tiles and applying the same linear
    projection to every flattened tile, but runs as one efficient convolution.

    Args:
        in_channels: number of input image channels.
        patch_size: side length of the square patches.
        emb_size: embedding dimension of each patch token.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        # Call nn.Module.__init__ FIRST so registering submodules below is safe
        # (the original set an attribute before super().__init__()).
        super().__init__()
        self.patch_size = patch_size
        # (b, c, H, W) -> (b, emb_size, H/patch_size, W/patch_size)
        self.projection = nn.Conv2d(
            in_channels, emb_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: Tensor) -> Tensor:
        # Project, then flatten the spatial grid and move the embedding axis
        # last: (b, e, h, w) -> (b, h*w, e). Pure-torch equivalent of the
        # einops Rearrange('b e (h) (w) -> b (h w) e') layer.
        x = self.projection(x)
        return x.flatten(2).transpose(1, 2)

`torch.Size([1, 196, 768])`

## CLS 令牌和位置嵌入

class PatchEmbedding(nn.Module):
    """Patchify + project an image, prepend a learnable [CLS] token and add
    learnable position embeddings (ViT, Dosovitskiy et al., 2021).

    Args:
        in_channels: number of input image channels.
        patch_size: side length of the square patches.
        emb_size: embedding dimension of each token.
        img_size: side length of the (square) input image — needed to size the
            position-embedding table. The original referenced a global
            ``img_size`` that was never defined (NameError); it is now a
            proper parameter, matching how ViT passes it positionally.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16,
                 emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size
        # Conv with stride == kernel == patch_size cuts the image into patches
        # and linearly projects each one in a single op.
        self.projection = nn.Conv2d(
            in_channels, emb_size, kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # One position vector per patch, plus one for the [CLS] token.
        num_patches = (img_size // patch_size) ** 2
        self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b = x.shape[0]
        # (b, c, H, W) -> (b, e, h, w) -> (b, h*w, e)
        x = self.projection(x).flatten(2).transpose(1, 2)
        # Broadcast the single learnable [CLS] token across the batch and
        # prepend it to the patch sequence.
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        # Add learned position information to every token (out-of-place).
        x = x + self.positions
        return x

Transformer 编码器 (Vaswani et al., 2017) 由多头自注意力和 MLP 块的交替层组成。在每个块之前应用Layer Norm (LN)，并在每个块之后添加残差连接。

## 注意力机制

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (Vaswani et al., 2017).

    Fixes over the original:
      * ``self.num_heads`` was never assigned, so forward() raised
        AttributeError.
      * the scaling was divided out of the softmax OUTPUT; the scores must be
        scaled BEFORE the softmax.
      * the ``mask`` argument (and ``fill_value``) were computed but never
        applied.

    Args:
        emb_size: token embedding dimension; must be divisible by num_heads.
        num_heads: number of parallel attention heads.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError("emb_size must be divisible by num_heads")
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads
        # One linear layer produces queries, keys and values at once.
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        b, n, _ = x.shape
        # (b, n, 3e) -> (3, b, heads, n, head_dim)
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        queries, keys, values = qkv[0], qkv[1], qkv[2]

        # Raw attention scores: (b, heads, query_len, key_len).
        energy = torch.matmul(queries, keys.transpose(-2, -1))
        if mask is not None:
            # Assumes a boolean mask where True marks positions to KEEP —
            # TODO(review): confirm against callers; the original never
            # applied the mask at all.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # Scale the scores BEFORE the softmax, by sqrt(d_head) as in
        # "Attention Is All You Need".
        scaling = self.head_dim ** 0.5
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        # Weighted sum of values, then merge heads back: (b, n, e).
        out = torch.matmul(att, values)
        out = out.transpose(1, 2).reshape(b, n, self.emb_size)
        return self.projection(out)

## 残差连接

class ResidualAdd(nn.Module):
    """Wrap a module ``fn`` and add its input back to its output (skip
    connection): ``y = fn(x) + x``.

    The original used in-place ``x += res``, which mutates the wrapped
    module's output; when ``fn`` aliases its input (e.g. nn.Identity) that
    silently corrupts the caller's tensor and can break autograd. The add is
    now out-of-place.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn  # any callable module taking (x, **kwargs)

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: expand the embedding by a factor of ``L``, apply
    GELU and dropout, then project back to ``emb_size``.

    ``L`` keeps its (admittedly terse) name because callers pass it by
    keyword.
    """

    def __init__(self, emb_size: int, L: int = 4, drop_p: float = 0.):
        hidden = L * emb_size
        layers = [
            nn.Linear(emb_size, hidden),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden, emb_size),
        ]
        super().__init__(*layers)

## Transformer 编码器块

class TransformerEncoderBlock(nn.Sequential):
    """One ViT encoder block: pre-norm multi-head attention and a pre-norm
    MLP, each wrapped in a residual (skip) connection.

    NOTE(review): the original text of this block was garbled — the
    ResidualAdd wrappers and the MultiHeadAttention layer were missing and
    the parentheses were unbalanced. Reconstructed following the standard
    pre-LN transformer layout described in the surrounding article.

    Args:
        emb_size: token embedding dimension.
        drop_p: dropout after attention and after the MLP.
        forward_expansion: hidden-size multiplier of the feed-forward block.
        forward_drop_p: dropout inside the feed-forward block.
        **kwargs: forwarded to MultiHeadAttention (e.g. num_heads, dropout).
    """

    def __init__(self, emb_size: int = 768, drop_p: float = 0.,
                 forward_expansion: int = 4, forward_drop_p: float = 0.,
                 **kwargs):
        super().__init__(
            # Attention sub-block: LayerNorm -> MHA -> dropout, plus skip.
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p),
            )),
            # MLP sub-block: LayerNorm -> FFN -> dropout, plus skip.
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(emb_size, L=forward_expansion,
                                 drop_p=forward_drop_p),
                nn.Dropout(drop_p),
            )),
        )

# Smoke test: one encoder block must preserve the (b, n, e) token shape.
patches_embedded = PatchEmbedding()(x)
encoded = TransformerEncoderBlock()(patches_embedded)
print(encoded.shape)

## Transformer编码器

class TransformerEncoder(nn.Sequential):
    """A stack of ``depth`` identical TransformerEncoderBlock layers."""

    def __init__(self, depth: int = 12, **kwargs):
        # kwargs are forwarded unchanged to every block.
        blocks = (TransformerEncoderBlock(**kwargs) for _ in range(depth))
        super().__init__(*blocks)

class ClassificationHead(nn.Sequential):
    """Mean-pool the token sequence, layer-normalize, and project the pooled
    vector to ``n_classes`` logits."""

    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        head = [
            Reduce('b n e -> b e', reduction='mean'),  # average over tokens
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes),
        ]
        super().__init__(*head)

## 整合所有的组件——VisionTransformer

class ViT(nn.Sequential):
    """Full Vision Transformer: patch embedding -> transformer encoder ->
    classification head.

    The original body omitted the ClassificationHead, leaving ``n_classes``
    unused and the model without a logits output; it is restored here, which
    matches the summary output below (final Linear to 1000 classes).

    Args:
        in_channels: input image channels.
        patch_size: side length of square patches.
        emb_size: token embedding dimension.
        img_size: side length of the (square) input image.
        depth: number of encoder blocks.
        n_classes: number of output classes.
        **kwargs: forwarded to each TransformerEncoderBlock.
    """

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes),
        )

# torchsummary report for a 3x224x224 input: ~86.4M parameters for the
# default ViT-Base configuration (see the table below).
print(summary(ViT(), (3, 224, 224), device='cpu'))

```================================================================
Total params: 86,415,592
Trainable params: 86,415,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 364.33
Params size (MB): 329.65
Estimated Total Size (MB): 694.56
----------------------------------------------------------------```

https://avoid.overfit.cn/post/da052c915f4b4309b5e6b139a69394c1