Press "Enter" to skip to content

矩阵视角下的Transformer详解(附代码)

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

 

©PaperWeekly 原创 · 作者 | 孙裕道

 

单位 | 北京邮电大学博士生

 

研究方向 | GAN图像生成、情绪对抗样本生成

 

 

引言

 

Transformer 模型是 Google 团队在 2017 年 6 月由 Ashish Vaswani 等人在论文《Attention Is All You Need》所提出,当前它已经成为 NLP 领域中的首选模型。Transformer 抛弃了 RNN 的顺序结构,采用了 Self-Attention 机制,使得模型可以并行化训练,而且能够充分利用训练资料的全局信息,加入 Transformer 的 Seq2seq 模型在 NLP 的各个任务上都有了显着的提升。本文从矩阵视角下做了大量的图示目的是能够更加清晰地讲解 Transforme 的运行原理,以及相关组件的操作细节,文末还有完整可运行的代码示例。

 

 

注意力机制

 

Transformer 中的核心机制就是 Self-Attention。Self-Attention 机制的本质来自于人类视觉注意力机制。当人视觉在感知东西时候往往会更加关注某个场景中显着性的物体,为了合理利用有限的视觉信息处理资源,人需要选择视觉区域中的特定部分,然后集中关注它。注意力机制主要目的就是对输入进行注意力权重的分配,即决定需要关注输入的哪部分,并对其分配有限的信息处理资源给重要的部分。

 

2.1 Self-Attention

 

 

Self-Attention 工作原理如上图所示,给定输入 word embedding 向量,然后对于输入向量通过矩阵进行线性变换得到向量,向量,以及向量,即:

 

 

如果令矩阵,,,,则此时则有:

 

 

接着再利用得到的 Query 向量和 Key 向量计算注意力得分,论文中采用的注意力计算公式为点积缩放公式:

 

 

论文中假定向量的元素和 Query 向量的元素独立同分布,且令均值为,方差为,则此时注意力向量的第个分量  的均值为,方差具体的计算公式如下:

 

 

令注意力分数矩阵,则有:

 

 

注意分数向量经过层得到归一化后的注意力分布,即为:

 

最后利用得到的注意力分布向量和矩阵获得最后的输出

,则有:

 

 

令输出矩阵,则有:

 

 

2.2 Multi-Head Attention

 

 

Multi-Head Attention 的工作原理与 Self-Attention 的工作原理非常类似。为了方便图解可视化将 Multi-Head 设置为 2-Head,如果 Multi-Head 设置为 8-Head,则上图的的下一步的分支数为。

 

给定输入 word embedding 向量,然后对于输入向量通过矩阵进行第一次线性变换得到 Query 向量,Key向量,以及 Value 向量。

 

然后再对 Query 向量通过矩阵和进行第二次线性变换得到和,同理对 Key 向量通过矩阵和进行第二次线性变换得到和,对 Value 向量通过矩阵和进行第二次线性变换得到和,具体的计算公式如下所示:

 

 

令矩阵:

 

 

此时则有:

 

 

对于每个 Head 利用得到对于 Query 向量和 Key 向量计算对应的注意力得分,其中注意力向量的第个分量的计算公式为:

令注意力分数矩阵,

,则有:

 

 

注意分数向量经过 softmax 层得到归一化后的注意力分布,即为:

 

 

对于每一个 Head 利用得到的注意力分布向量和 Value 矩阵获得最后的输出,则有:

 

 

两个 Head 的的向量按照如下方式拼接在一起,则有:

 

 

给定参数矩阵,则输出矩阵为:

 

 

综上所述则有:

 

 

2.3 Mask Self-Attention

 

如下图左半部分所示,Self-Attention 的输出向量综合了输入向量的全部信息,由此可见,Self-Attention 在实际编程中支持并行运算。如下图右半部分所示,Mask Self-Attention 的输出向量只利用了已知部分输入的向量的信息。例如,只是与有关;与和有关;与,和有关;与,,和有关。Mask Self-Attention 在 Transformer 中被用到过两次。

 

Transformer 的 Encoder 中如果输入一句话的 word 长度小于指定的长度,为了能够让长度一致往往会用 0 进行填充,此时则需要用 Mask Self-Attention 来计算注意力分布。

 

Transformer 的 Decoder 的输出是有时序关系的,当前的输出只与之前的输入有关,所以此时算注意力分布时需要用到 Mask Self-Attention。

 

 

 

Transformer模型

 

以上对 Transformer 中的核心内容即自注意力机制进行了详细解剖,接下来会对 Transformer 模型架构进行介绍。Transformer 模型是由 Encoder 和 Decoder 两个模块组成,具体的示意图如下所示,为了能够对 Transformer 内部的操作细节进行更清晰的展示,下图以矩阵运算的视角对 Transformer 的原理进行讲解。

 

Encoder 模块操作的具体流程如下所示:

 

Encoder 的输入由两部分组成分别是词编码矩阵和位置编码矩阵,其中表示句子数目,表示一句话单词的最大数目,表示的是词向量的维度。位置编码矩阵表示的是每个单词在一句里的所有位置信息,因为 Self-Attention 计算注意力分布的时候只能给出输出向量和输入向量之间的权重关系,但是不能给出词在一句话里的位置信息,所以需要在输入里引入位置编码矩阵。位置编码向量生成方法有很多。一种比较简单粗暴的方式就是根据单词在句子中的位置生成一个 one-hot 的位置编码;还有的方法是将位置编码当成参数进行训练学习;在该论文里是利用三角函数对位置进行编码,具体的公式如下所示:

 

 

其中表示的是位置编码向量,表示词在句子中的位置,表示编码向量的位置索引。

 

输入矩阵通过线性变换生成矩阵,,。在实际编程中是将输入直接赋值给,,。如果输入单词长度小于最大长度并来填充的时候,还要相应引入 Mask 矩阵。

 

将矩阵,,输入到 Multi-Head Attention 模块中进行注意分布的计算得到矩阵,计算公式为:

 

 

具体的计算细节参考上文关于 Multi-Head Attention 原理的讲解不在这里赘述。然后将原始输入与注意力分布进行残差计算得到输出矩阵。

 

对矩阵进行层归一化操作得到,具体的计算公式为:

 

 

将输入到全连接神经网络中得到,然后再让全连接神经网络的输入与输出进行残差计算得到,接着对进行层归一化操作。

 

以上是一个 Block 的操作原理,将个 Block 进行堆叠就组成了 Encoder 的模块,得到的最后输出为。这里需要注意的是 Encoder 模块中的各个组件的操作顺序并不是固定的,也可以先进行归一化操作,然后再计算注意力分布,再归一化,再预测等。

 

Decoder 模块操作的具体流程如下所示:

 

Decoder 的输入也由两部分组成分别是词编码矩阵和位置编码矩阵。因为 Decoder 的输入是具有时顺序关系的(即上一步的输出为当前步输入)所以还需要输入 Mask 矩阵以便计算注意力分布。

 

输入矩阵通过线性变换生成矩阵,,。在实际编程中是将输入直接赋值给,,。如果输入单词长度小于最大长度并 0 来填充的时候,还要相应引入 Mask 矩阵。

 

将矩阵,,以及 Mask 矩阵输入到 Mask Multi-Head Attention 模块中进行注意分布的计算得到矩阵,计算公式为:

 

 

具体的计算细节参考上文关于 Mask Self-Attention 的讲解不在这里赘述。然后将原始输入与注意力分布进行残差计算得到输出矩阵。

 

接着再对矩阵进行层归一化操作得到。

 

Encoder 的输出通过线性变换得到和,进行线性变换得到,利用矩阵和和进行交叉注意力分布的计算得到,计算公式为:

 

这里的交叉注意力分布综合 Encoder 输出结果和 Decoder 中间结果的信息。实际编程编程中将直接赋值给和,直接赋值给。然后将与注意力分布进行残差计算得到输出矩阵。

 

接着对进行层归一操作得到,再将输入到全连接神经网络中得到,接着再做一步残差操作得到,最后再进行一层归一化操作。

 

以上是一个 Block 的操作原理,将个 Block 进行堆叠就组成了 Decoder 的模块,得到的输出为。然后在词汇字典中找到当前预测最大概率的单词,并将该单词词向量作为下一阶段的输入,重复以上步骤,直到输出“end”字符为止。

 

 

 

程序代码

 

Transformer 具体的代码示例如下所示。根据上文中 Multi-Head Attention 原理示例图可知,严格来看 Multi-Head Attention 在求注意分布的时候中间其实是有两步线性变换。给定输入向量第一步线性变换直接让向量赋值给,,,这一过程以下程序中有所体现,在这里并不会产生歧义。

 

第二步线性变换产生多 Head,假设的时候,按理说要与个矩阵进行线性变换得到个,同理要与个矩阵进行线性变换得到个,要与个矩阵进行线性变换得到个,如果按照这个方式在程序实现则需要定义 24 个权重矩阵,非常的麻烦。

 

以下程序中有一个简单的权重定义方法,通过该方法也可以实现以上多Head的线性变换,以向量为例:

 

首先将向量进行截断分成个向量,即为:

 

 

其中是的第个截断向量,是单位矩阵,是零矩阵。

 

然后对用相同的权重矩阵进行线性变换,此时可以发现,训练过程的时候只需要更新权重矩阵即可,而且可以进行多 Head 线性变换,个权重矩阵可以表示为:

 

 

其中权重矩阵。

 

    import torch
    import torch.nn as nn
    import os
    class SelfAttention(nn.Module):
        def __init__(self, embed_size, heads):
            super(SelfAttention, self).__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads
            assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
        def forward(self, values, keys, query, mask):
            N =query.shape[0]
            value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]
            # split embedding into self.heads pieces
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            queries = query.reshape(N, query_len, self.heads, self.head_dim)
            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)
            energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
            # queries shape: (N, query_len, heads, heads_dim)
            # keys shape : (N, key_len, heads, heads_dim)
            # energy shape: (N, heads, query_len, key_len)
            if mask is not None:
                energy = energy.masked_fill(mask == 0, float("-1e20"))
            attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)
            out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
            # attention shape: (N, heads, query_len, key_len)
            # values shape: (N, value_len, heads, heads_dim)
            # (N, query_len, heads, head_dim)
            out = self.fc_out(out)
            return out
    class TransformerBlock(nn.Module):
        def __init__(self, embed_size, heads, dropout, forward_expansion):
            super(TransformerBlock, self).__init__()
            self.attention = SelfAttention(embed_size, heads)
            self.norm1 = nn.LayerNorm(embed_size)
            self.norm2 = nn.LayerNorm(embed_size)
            self.feed_forward = nn.Sequential(
                nn.Linear(embed_size, forward_expansion*embed_size),
                nn.ReLU(),
                nn.Linear(forward_expansion*embed_size, embed_size)
            )
            self.dropout = nn.Dropout(dropout)
        def forward(self, value, key, query, mask):
            attention = self.attention(value, key, query, mask)
            x = self.dropout(self.norm1(attention + query))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm2(forward + x))
            return out
    class Encoder(nn.Module):
        def __init__(
                self,
                src_vocab_size,
                embed_size,
                num_layers,
                heads,
                device,
                forward_expansion,
                dropout,
                max_length,
            ):
            super(Encoder, self).__init__()
            self.embed_size = embed_size
            self.device = device
            self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
            self.position_embedding = nn.Embedding(max_length, embed_size)
            self.layers = nn.ModuleList(
                [
                    TransformerBlock(
                        embed_size,
                        heads,
                        dropout=dropout,
                        forward_expansion=forward_expansion,
                        )
                    for _ in range(num_layers)]
            )
            self.dropout = nn.Dropout(dropout)
        def forward(self, x, mask):
            N, seq_length = x.shape
            positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
            out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
            for layer in self.layers:
                out = layer(out, out, out, mask)
            return out
    class DecoderBlock(nn.Module):
        def __init__(self, embed_size, heads, forward_expansion, dropout, device):
            super(DecoderBlock, self).__init__()
            self.attention = SelfAttention(embed_size, heads)
            self.norm = nn.LayerNorm(embed_size)
            self.transformer_block = TransformerBlock(
                embed_size, heads, dropout, forward_expansion
            )
            self.dropout = nn.Dropout(dropout)
        def forward(self, x, value, key, src_mask, trg_mask):
            attention = self.attention(x, x, x, trg_mask)
            query = self.dropout(self.norm(attention + x))
            out = self.transformer_block(value, key, query, src_mask)
            return out
    class Decoder(nn.Module):
        def __init__(
                self,
                trg_vocab_size,
                embed_size,
                num_layers,
                heads,
                forward_expansion,
                dropout,
                device,
                max_length,
        ):
            super(Decoder, self).__init__()
            self.device = device
            self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
            self.position_embedding = nn.Embedding(max_length, embed_size)
            self.layers = nn.ModuleList(
                [DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)]
                )
            self.fc_out = nn.Linear(embed_size, trg_vocab_size)
            self.dropout = nn.Dropout(dropout)
        def forward(self, x ,enc_out , src_mask, trg_mask):
            N, seq_length = x.shape
            positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
            x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))
            for layer in self.layers:
                x = layer(x, enc_out, enc_out, src_mask, trg_mask)
            out =self.fc_out(x)
            return out
    class Transformer(nn.Module):
        def __init__(
                self,
                src_vocab_size,
                trg_vocab_size,
                src_pad_idx,
                trg_pad_idx,
                embed_size = 256,
                num_layers = 6,
                forward_expansion = 4,
                heads = 8,
                dropout = 0,
                device="cuda",
                max_length=100
            ):
            super(Transformer, self).__init__()
            self.encoder = Encoder(
                src_vocab_size,
                embed_size,
                num_layers,
                heads,
                device,
                forward_expansion,
                dropout,
                max_length
                )
            self.decoder = Decoder(
                trg_vocab_size,
                embed_size,
                num_layers,
                heads,
                forward_expansion,
                dropout,
                device,
                max_length
                )
            self.src_pad_idx = src_pad_idx
            self.trg_pad_idx = trg_pad_idx
            self.device = device
        def make_src_mask(self, src):
            src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
            # (N, 1, 1, src_len)
            return src_mask.to(self.device)
        def make_trg_mask(self, trg):
            N, trg_len = trg.shape
            trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
                N, 1, trg_len, trg_len
            )
            return trg_mask.to(self.device)
        def forward(self, src, trg):
            src_mask = self.make_src_mask(src)
            trg_mask = self.make_trg_mask(trg)
            enc_src = self.encoder(src, src_mask)
            out = self.decoder(trg, enc_src, src_mask, trg_mask)
            return out
    if __name__ == '__main__':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(device)
        x = torch.tensor([[1,5,6,4,3,9,5,2,0],[1,8,7,3,4,5,6,7,2]]).to(device)
        trg = torch.tensor([[1,7,4,3,5,9,2,0],[1,5,6,2,4,7,6,2]]).to(device)
        src_pad_idx = 0
        trg_pad_idx = 0
        src_vocab_size = 10
        trg_vocab_size = 10
        model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
        out = model(x, trg[:, : -1])
        print(out.shape)

 

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。