## 注意力机制

1. 在输入信息上计算注意力分布

1. 根据注意力分布计算输入信息的加权平均

## 软注意力机制

$$\alpha_i = p(z=i \mid X, q) = \mathrm{softmax}\big(s(x_i, q)\big) = \frac{\exp\big(s(x_i, q)\big)}{\sum_j \exp\big(s(x_j, q)\big)}$$

每个键向量 $k_i$ 与查询向量 $q$ 通过注意力打分函数 $s(\cdot)$ 计算出对应的注意力权重 $\alpha_i$，然后对值向量加权求和 $\sum_i \alpha_i v_i$ 即可。

Transformer中使用的是 自注意力机制（self-attention） ，那么什么是自注意力机制呢？

### 代码实现

```import torch
import torch.nn as nn
import math
import torch.nn.functional as F
# 加性模型
class attention1(nn.Module):
def __init__(self, q_size, k_size, v_size, seq_len):
# q、k、v的维度，seq_len每句话中词的数量
super(attention1, self).__init__()
self.linear_v = nn.Linear(v_size, seq_len)
self.linear_W = nn.Linear(k_size, k_size)
self.linear_U = nn.Linear(q_size, q_size)
self.tanh = nn.Tanh()

def forward(self, query, key, value, dropout=None):
key = self.linear_W(key)
query = self.linear_U(query)
k_q = self.tanh(query + key)
alpha = self.linear_v(k_q)
alpha = F.softmax(alpha, dim=-1)
out = torch.bmm(alpha, value)
return out, alpha
attention_1 = attention1(100, 100, 100, 10)
q = k = v = torch.randn((8,10,100)) # 可以理解为有8句话，每句话有10个词，每个词用100维的向量来表示
out, attn = attention_1(q, k, v)
print(out.shape)
print(attn.shape)```

```import torch
import torch.nn as nn
import math
import torch.nn.functional as F
# 点积模型
class attention2(nn.Module):
def __init__(self):
super(attention2, self).__init__()
def forward(self, query, key, value, dropout=None):
alpha = torch.bmm(query, key.transpose(-1, -2))
alpha = F.softmax(alpha, dim=-1)
out = torch.bmm(alpha, value)
return out, alpha
attention_2 = attention2()
q = k = v = torch.randn((8,10,100))
out, attn = attention_2(q, k, v)
print(out.shape)
print(attn.shape)```

transformer用的就是这种注意力模型，不过是多头，下面会讲到

```import torch
import torch.nn as nn
import math
import torch.nn.functional as F
# 缩放点积模型
class attention3(nn.Module):
def __init__(self):
# q、k、v的维度，seq_len每句话中词的数量
super(attention3, self).__init__()
def forward(self, query, key, value, dropout=None):
d = k.size(-1)
alpha = torch.bmm(query, key.transpose(-1, -2)) / math.sqrt(d)
alpha = F.softmax(alpha, dim=-1)
out = torch.bmm(alpha, value)
return out, alpha
attention_3 = attention3()
q = k = v = torch.randn((8,10,100))
out, attn = attention_3(q, k, v)
print(out.shape)
print(attn.shape)```

```import torch
import torch.nn as nn
import math
import torch.nn.functional as F
# 双线性模型
class attention4(nn.Module):
def __init__(self, x_size):
# seq_len每句话中词的数量
super(attention4, self).__init__()
self.linear_W = nn.Linear(x_size, x_size)
def forward(self, query, key, value, dropout=None):
alpha = torch.bmm(query, self.linear_W(key).transpose(-1, -2))
alpha = F.softmax(alpha, dim=-1)
out = torch.bmm(alpha, value)
return out, alpha
attention_4 = attention4(100)
q = k = v = torch.randn((8,10,100))
out, attn = attention_4(q, k, v)
print(out.shape)
print(attn.shape)```

## 硬注意力机制

1. 选取最高概率的一个输入向量

1. 通过在注意力分布式上随机采样的方式实现（类似掷骰子）

## 多头注意力机制

### 代码实现

```import torch
import torch.nn as nn
import math
import torch.nn.functional as F
# 缩放点积模型
class attention3(nn.Module):
def __init__(self):
super(attention3, self).__init__()
def forward(self, query, key, value, dropout=None):
d = key.size(-1)
alpha = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d)
alpha = F.softmax(alpha, dim=-1)
out = torch.matmul(alpha, value)
return out, alpha

assert embedding_size % head == 0 # 得整分
self.W_K = nn.Linear(embedding_size, embedding_size)
self.W_Q = nn.Linear(embedding_size, embedding_size)
self.W_V = nn.Linear(embedding_size, embedding_size)
self.fc = nn.Linear(embedding_size, embedding_size)
self.dropout = nn.Dropout(dropout)
self.attention = attention3()
def forward(self, query, key, value):
batch_size = query.size(0)
# 转换成多头，一次矩阵乘法即可完成
query = self.W_Q(query).view(batch_size, self.head, -1, self.d_k)
key = self.W_K(key).view(batch_size, self.head, -1, self.d_k)
value = self.W_V(value).view(batch_size, self.head, -1, self.d_k)
out, alpha = self.attention(query, key, value, self.dropout)
out = out.view(batch_size, -1, self.d_k * self.head)
out = self.fc(out)
return out, alpha
c = torch.randn((4,5,20))
out, alpha = m(c,c,c)
print(out.shape)
print(alpha.shape)```

```query = self.W_Q(query).view(batch_size, self.head, -1, self.d_k)
# The line above reshapes directly to (batch, head, seq, d_k), which scrambles
# word positions across heads. The line below is the correct split: first view
# as (batch, seq, head, d_k), then transpose dims 1 and 2 to (batch, head, seq, d_k).
query = self.W_Q(query).view(batch_size, -1,self.head, self.d_k).transpose(1, 2)```

https://www.bilibili.com/video/BV1DK411M73n?p=9&vd_source=f57738ab6bbbbd5fe07aae2e1fa1280f