浅析transformer的结构

本文最后更新于 2024年7月29日晚上7点17分

浅析transformer结构

transformer的结构有点复杂,中间涉及很多重要结构。这里会大致理解每个结构的公式原理,详细的解释会在每个细节帖子中展现。

transformer是NLP模型很常见的结构。
大语言模型的处理流程,通俗理解是这样:
假设模型现在已经训练好,预测过程:输入文本,经过模型得到预测的下一个字,然后把这个字和之前内容输入,得到下一个字预测。迭代最终得到预测的回答。

transformer的结构和原理:

transformer结构

首先输入一段文本或者句子,经过tokenizer把文本进行tokenization,切分为一个个tokens,然后tokens通过embedding向量化,然后添加位置编码,得到向量。

向量输入给transformer encoder,得到attention block输出的Q和K给到decoder。

然后“Outputs”部分的输入本文经过input embedding+positional encoding得到向量,经过attention模块得到的V和之前ecoder的Q和K一起输入cross attention模块。最终得到的预测向量经过linear+softmax得到预测概率向量,概率最大的token为预测token。

train时,预测token向量会和真实token向量计算loss,更新模型参数。

transformer可以分为以下几个模块:

  1. tokenizer部分:把文本转为向量
  2. positional encoding:给向量添加位置编码
  3. attention模块
  4. 训练和预测的细节

1. tokenizer

tokenizer的作用就是把文字转换为向量,能方便输入给模型进一步训练处理。

tokenizer的细节看这个博客:点击这里查看tokenizer原理介绍

2. positional encoding

位置编码是给文本序列添加位置信息。因为一段文本,同一个单词在不同位置的含义是不一样的,记录单词在文本中的位置信息很重要。

位置信息的公式,在“attention is all you need”论文中是这样的:
positional encoding其中,pos是这个token在序列中的位置,例如class token的pos=0;
假设这个序列PE总共有N个token,每个token的长度是K,则序列的形状是N*K,pos的范围是0~N-1;2i和2i+1是每个token向量中的index,范围是0~K-1.
通俗的,PE是一个矩阵,有N行,每一行表示一个token向量;(pos, 2i)位置就是第pos个token,在index=2i位置的元素。

positional encoding的细节看这里:位置编码的原理

3.Attention结构

单头注意力机制

attention结构是transformer当中最重要的部分。论文中attention的公式是这样的:

其中,Q,K,V是向量组${a_1,a_2,…a_n}$乘以矩阵$W^q, W^k, W^v$得到的。

多头注意力机制:

多头注意力机制,会对每个Q K V分别乘以$W_i^q,W_i^k,W_i^v$,得到n_head组的${Q_i,K_i,V_i}$,其中n_head是head的个数,一般为8.
公式如下:

掩码注意力机制

输入矩阵X先上面的多头注意力机制一样,得到Q K V矩阵,然后$\frac{QK^T}{\sqrt{d)k}}$会和掩码矩阵相乘,再经过softmax,然后和V相乘。公式如下:

掩码矩阵就是$M=[a_{ij}], \quad a_{ij}=0 \quad if \quad i < j $.

交叉注意力机制

cross attention交叉注意力机制。
在transformer的decoder block中,第一个是multi-head attention多头注意力机制,其中Q K V都是Outputs向量得到的,但是第二个multi-head attention的K和V是encoder得到的,Q是decoder得到的。

通过encoder的输出向量C计算得到K和V,再根据decoder block输出的Z计算出Q,后续Q K V的计算和上面一样。
好处是:decoder的每位单词可以利用到encoder所有单词信息。

上看只是浅析了各种attention的公式,具体的详细介绍各种attention结构原理:
详细理解attention的原理

关于attention结构的pytorch面试题,看这里,手动pytorch搭建attention结构:
pytorch面试题:实现attention结构

残差链接&LN

transformer结构中有残差链接residual connection,并且使用了Layer Norm。

激活函数是GELU,优化器是Adam。损失函数是Cross entropy loss交叉熵损失函数。

machine learning中的各种激活函数:
激活函数大全

machine learning中的各种优化器optimizer,及其作用:
优化器介绍I——BGD/SGD/MBGD, 优化器介绍II——动量&自适应学习率

参考:

transformer论文:Attention Is All You Need
https://arxiv.org/pdf/1706.03762


浅析transformer的结构
https://kangkang37.github.io/2024/07/05/transformer/
作者
kangkang
发布于
2024年7月5日
许可协议