LLaMA(Large Language Model Meta AI)是发布于 2023 年 2 月 的开源预训练大型语言模型,与 GPT 等生成模型类似,LLaMA 也只使用了 Transformer 的解码器部分。
1、LLaMA 1
LLaMA 是一组从 70 亿到 650 亿参数的基础语言模型,特别的,LLaMA-13B 在大多数基准测试中优于 GPT-3(1750亿参数)。
1.1、预训练数据
数据及其在训练集中的比例:
- English CommonCrawl [67%]: 使用 CCNet 流程对 2017 年至 2020 年间的五个 CommonCrawl 数据集进行了预处理。该过程在行级别去重,使用 fastText 线性分类器进行语言识别以移除非英文页面,并使用 n-gram 语言模型过滤低质量内容。此外,训练了一个线性模型,用于将维基百科中的参考页面与随机抽取的页面进行分类,并丢弃那些未被分类为参考页面的页面
- C4 [15%]: 使用多样化的预处理 CommonCrawl 数据集可以提高性能,因此,将公开可用的 C4 数据集纳入数据中。C4 的预处理也包括去重和语言识别步骤:与 CCNet 的主要区别在于质量过滤,C4 主要依赖一些启发式方法,如页面中是否存在标点符号或网页中的单词和句子的数量
- Github [4.5%]: 使用 Google BigQuery 上可用的公共 GitHub 数据集。根据行长度或字母数字字符的比例使用启发式方法过滤低质量文件,并使用正则表达式移除样板文件(如标题)。最后,在文件级别对结果数据集进行了去重,去除了完全相同的文件
- Wikipedia [4.5%]: 添加了 2022 年 6 月至 8 月期间的维基百科转储,涵盖了 20 种使用拉丁字母或西里尔字母的语言。并对数据进行处理以去除超链接、评论和其他格式化的模板
- Gutenberg and Books3 [4.5%]: 训练数据集中包含了两个书籍语料库:古腾堡计划(Gutenberg Project)以及 ThePile 中的 Books3 部分。在书籍级别进行了去重,去除了那些内容重叠超过 90% 的书籍
- ArXiv [2.5%]: 处理 arXiv 上的 Latex 文件,以将科学数据添加到数据集中。去除第一节之前的内容以及参考文献部分,同时去除 .tex 文件中的注释,并展开了用户编写的内联定义和宏,以增加论文之间的一致性
- Stack Exchange [2%]: Stack Exchange 是一个高质量问题和答案的网站,涵盖了从计算机科学到化学等多种领域。数据集移除了文本中的 HTML 标签,并按分数(从高到低)对答案进行了排序
1.2、分词器
使用字节对编码(BPE)算法对数据进行分词。特别的,所有数字会拆分为单个数字,并在遇到未知的 UTF-8 字符时回退到使用字节进行分解。整个训练数据集在分词后大约包含 1.4 万亿个词元(token)。对于大多数训练数据,每个 token 在训练过程中仅使用一次,唯一的例外是维基百科和书籍领域数据进行了大约两轮的训练。
1.3、模型架构
LLaMA 网络基于 Transformer 架构,并在模型做了一些改进:
- 预标准化: 为了提高训练的稳定性,对每个 Transformer 子层的输入使用 RMSNorm(Root Mean Square Normalization)进行归一化,而不是对输出进行归一化
- SwiGLU 激活函数: 采用 SwiGLU(Swish-Gated Linear Unit)激活函数代替了 ReLU 非线性函数,使用的维度为 $\frac{2}{3} \times 4d$。SwiGLU 结合了 Swish 激活函数的非线性特性和门控机制,通过引入门控机制来控制信息的流动。相比于 ReLU,SwiGLU 在某些情况下能够更好地捕捉到复杂的模式,提高模型的表现力和性能
- 旋转位置嵌入: 在网络的每一层添加旋转位置嵌入(RoPE)来取代传统的绝对位置嵌入方法。RoPE 通过在复数域中旋转位置向量来编码序列中的位置信息,这使得模型能够更有效地捕捉相对位置的信息
RMSNorm
RMSNorm(Root Mean Square Normalization)对输入进行归一化时,不计算均值,而是基于 均方根(RMS) 对输入进行标准化。 假设输入是一个张量 $\mathbf{x} \in \mathbb{R}^{B \times D}$,其中 $B$ 是批大小(batch size),$D$ 是特征维度。
1、计算输入的每个维度的均方根值:
\[\text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{B} \sum_{i=1}^{B} x_i^2}\]2、然后,通过将输入特征除以均方根值来进行标准化:
\[\hat{x}_i = \frac{x_i}{\text{RMS}(\mathbf{x})}\]相比于层归一化(Layer Normalization)通过计算每个样本在所有特征上的均值和方差来进行归一化,RMSNorm 不进行去中心化操作,不需要计算均值只依赖于均方根,在计算上更加高效。
SwiGLU
SwiGLU(Switched Gated Linear Unit)在标准的 GLU 激活函数基础上引入了 “切换”(switched) 机制:
\[\text{SwiGLU}(x) = (x \odot \sigma(W_1 x + b_1)) + \gamma \odot (x \odot \sigma(W_2 x + b_2))\]其中,$x$ 是输入向量,$W_1, W_2$ 是两组线性变换的权重矩阵,$b_1, b_2$ 是对应的偏置项,$\sigma$ 表示 Sigmoid激活函数:$\sigma(z) = \frac{1}{1 + e^{-z}}$,$\odot$ 表示逐元素相乘(Hadamard积)。 对于每个线性变换结果,SwiGLU 使用 Sigmoid 激活函数生成一个介于 0 到 1 之间的值,作为 门控(gating) 系数,从而动态地调节每个元素的影响力。相比传统的 ReLU,SwiGLU 能有效避免死神经元问题(神经元在训练过程中由于梯度消失而无法更新)。此外,门控机制使得 SwiGLU 可以动态地选择不同特征进行激活,有效捕捉不同层次的语义信息
RoPE
RoPE(Rotary Positional Embedding)通过对
Query
和Key
向量施加特定的旋转变换操作来编码序列中每个位置的信息,而不再依赖于固定的正弦和余弦函数或直接学习位置向量。1、初步介绍:
设 $S_{N}=\lbrace w_{i} \rbrace _{i=1}^{N}$ 为一个包含 $N$ 个输入标记的序列,其中 $w_i$ 是第 $i^{th}$ 个元素。相应的词嵌入表示为 $E_N= \lbrace x_i \rbrace _{i=1}^{N}$,其中 $x_i \in R^{d}$ 是标记 $w_i$ 的 $d$ 维词嵌入向量,且未包含位置信息。自注意力首先将位置信息融入词嵌入,然后将其转换为查询、键和值表示:
\[\begin{align*} q_{m}&=f_{q}(x_{m},m)\\ k_{n}&=f_{k}(x_{n},n)\\ v_{n}&=f_{v}(x_{n},n), \end{align*} \tag {1}\]其中 $q_{m}, k_{n}$ 和 $v_{n}$ 分别通过 $f_{q}, f_{k}$ 和 $f_{v}$ 融入第 $m^{th}$ 和 $n^{th}$ 位置。然后使用查询和键值来计算注意力权重,而输出则作为值表示的加权和计算得出:
\[\begin{align*} a_{m, n}&=\frac{\exp\left(\frac{q_{m}^{\top} k_{n}}{\sqrt{d}}\right)}{\sum_{j=1}^{N}\exp\left(\frac{q_{m}^{\top} k_{j}}{\sqrt{d}}\right)}\\ o_{m}&=\sum_{n=1}^{N} a_{m, n} v_{n} \end{align*} \tag {2}\]2、绝对位置嵌入:
等式(1)的一个典型方法为:
\[f_{t: t\in\{q,k,v\}}\left(x_{i},i\right):=W_{t: t\in\{q,k,v\}}\left(x_{i}+p_{i}\right) \tag {3}\]其中 $p_{i}\in R^{d}$ 是一个依赖于标记 $x_{i}$ 的位置的 $d$ 维向量。Transformer 论文提出了使用正弦函数生成 $p_{i}$:
\[\begin{cases} p_{i,2t}&=\sin\left(k/ 10000^{2 t/ d}\right)\\ p_{i,2t+1}&=\cos\left(k/ 10000^{2 t/ d}\right) \end{cases} \tag {4}\]其中 $p_{i, 2 t}$ 是 $d$ 维向量 $p_{i}$ 的第 $2 t$ 个元素。与此不同的,RoPE 不是通过直接将位置添加到上下文表示中来结合相对位置信息,而是通过乘以正弦函数来结合相对位置信息。
3、二维旋转位置编码:
为了引入相对位置信息,需要查询 $q_{m}$ 和键 $k_{n}$ 的内积通过函数 $g$ 来形式化,该函数只取词嵌入 $x_{m}, x_{n}$ 和它们的相对位置 $m-n$ 作为输入变量(希望内积仅以相对形式编码位置信息):
\[\left\langle f_{q}\left(x_{m},m\right),f_{k}\left(x_{n},n\right)\right\rangle=g\left(x_{m},x_{n},m-n\right) \tag{5}\]最终目标是找到一个等效的编码机制使上述关系成立。
从词向量的维度为简单的二维情况开始,利用二维平面上向量的几何属性及其复数形式来证明等式(5)的一个解决方案:
\[\begin{align*} f_{q}\left(x_{m}, m\right)&=\left(W_{q} x_{m}\right) e^{i m\theta}\\ f_{k}\left(x_{n}, n\right)&=\left(W_{k} x_{n}\right) e^{i n\theta}\\ g\left(x_{m}, x_{n}, m-n\right)&=\operatorname{Re}\left[\left(W_{q} x_{m}\right)\left(W_{k} x_{n}\right)^{*} e^{i(m-n)\theta}\right] \end{align*} \tag{6}\]其中 $\operatorname{Re}[\cdot]$ 是复数的实部,$\left(W_{k} x_{n}\right)^{*}$ 表示 $\left(W_{k} x_{n}\right)$ 的共轭复数。$\theta \in R$ 是一个预设的非零常数。
进一步将 $f_{{q, k}}$ 写成乘法矩阵的形式:
\[f_{\{q, k\}}\left(x_{m}, m\right)=\left(\begin{array}{cc}\cos m\theta&-\sin m\theta\\ \sin m\theta&\cos m\theta\end{array}\right)\left(\begin{array}{cc}W_{\{q, k\}}^{(11)}& W_{\{q, k\}}^{(12)}\\ W_{\{q, k\}}^{(21)}& W_{\{q, k\}}^{(22)}\end{array}\right)\left(\begin{array}{c}x_{m}^{(1)}\\ x_{m}^{(2)}\end{array}\right) \tag{7}\]其中 $(x_{m}^{(1)}, x_{m}^{(2)})$ 是 $x_{m}$ 在二维坐标系中的表示。同样,$g$ 可以被视为一个矩阵,从而在二维情况下实现等式(5)的解决方案。结合相对位置编码是直接的:简单地通过其位置索引的角度倍数旋转仿射变换后的词嵌入向量,即
Query
向量乘以一个旋转矩阵。最终 $g\left(x_{m}, x_{n}, m-n\right)$ 可以表示如:
\[g\left(\boldsymbol{x}_m, \boldsymbol{x}_n, m-n\right)=\left(\begin{array}{ll} \boldsymbol{q}_m^{(1)} & \boldsymbol{q}_m^{(2)} \end{array}\right)\left(\begin{array}{cc} \cos ((m-n) \theta) & -\sin ((m-n) \theta) \\ \sin ((m-n) \theta) & \cos ((m-n) \theta) \end{array}\right)\binom{k_n^{(1)}}{k_n^{(2)}} \tag{8}\]4、多维一般形式:
为了将在二维情况下的结果推广到任何偶数维度的 $x_{i} \in R^{d}$,将 $d$ 维空间划分为 $d/2$ 个子空间,并利用内积的线性性质将它们组合起来,将 $f_{{q,k}}$ 转换为:
\[f_{\{q,k\}}(x_{m},m)=R_{\Theta,m}^{d}W_{\{q,k\}}x_{m} \tag{9}\]其中
\[R_{\Theta,m}^{d}=\begin{pmatrix}\cos m\theta_{1}&-\sin m\theta_{1}&0&0&\cdots&0&0\\ \sin m\theta_{1}&\cos m\theta_{1}&0&0&\cdots&0&0\\ 0&0&\cos m\theta_{2}&-\sin m\theta_{2}&\cdots&0&0\\ 0&0&\sin m\theta_{2}&\cos m\theta_{2}&\cdots&0&0\\ \vdots&\vdots&\vdots&\vdots&\ddots&\vdots&\vdots\\ 0&0&0&0&\cdots&\cos m\theta_{d/2}&-\sin m\theta_{d/2}\\ 0&0&0&0&\cdots&\sin m\theta_{d/2}&\cos m\theta_{d/2}\end{pmatrix} \tag{10}\]是具有预定义参数 $\Theta={\theta_{i}=10000^{-2(i-1)/d}, i \in [1,2,\ldots,d/2]}$ 的旋转矩阵。
将 RoPE 应用于自注意力机制中的方程(2):
\[q_{m}^{\intercal}k_{n}=(R_{\Theta,m}^{d} W_{q}x_{m})^{\intercal}(R_{\Theta,n}^{d} W_{k}x_{n})=x^{\intercal}W_{q}R_{\Theta,n-m}^{d}W_{k}x_{n} \tag{11}\]其中
\[R^{d}_{\Theta,n-m}=(R^{d}_{\Theta,m})^{\intercal}R^{d}_{\Theta,n} \tag{12}\]注意 $R^{d}_{\Theta}$ 是一个正交矩阵,这确保了在编码位置信息的过程中稳定性。
此外,由于 $R_{\Theta}^{d}$ 的稀疏性,直接应用矩阵乘法如方程(11)在计算上并不高效。通过将向量 $x$ 与一组预先计算的旋转系数相乘($\otimes$:逐位相乘),避免了直接计算旋转矩阵与向量的标准矩阵乘法,从而减少了计算复杂度:
\[R_{\Theta,m}^{d}x=\begin{pmatrix}x_{1}\\ x_{2}\\ x_{3}\\ x_{4}\\ \vdots\\ x_{d-1}\\ x_{d}\end{pmatrix}\otimes\begin{pmatrix}\cos m\theta_{1}\\ \cos m\theta_{1}\\ \cos m\theta_{2}\\ \cos m\theta_{2}\\ \vdots\\ \cos m\theta_{d/2}\\ \cos m\theta_{d/2}\end{pmatrix}+\begin{pmatrix}-x_{2}\\ x_{1}\\ -x_{4}\\ x_{3}\\ \vdots\\ -x_{d}\\ x_{d-1}\end{pmatrix}\otimes\begin{pmatrix}\sin m\theta_{1}\\ \sin m\theta_{1}\\ \sin m\theta_{2}\\ \sin m\theta_{2}\\ \vdots\\ \sin m\theta_{d/2}\\ \sin m\theta_{d/2}\end{pmatrix} \tag{13}\]与之前工作中采用的加法性质的位置编码方法相比,RoPE 的方法是乘法的。此外,RoPE 通过旋转矩阵乘积自然地结合了相对位置信息,而不是在应用自注意力时,改变加法位置编码扩展形式中的项。
RoPE 的图形说明如图所示,对于一个
token
序列:a. 每个词嵌入通过线性变换生成其对应的
Query
和Key
向量b. 根据
token
在序列中的位置,生成每个位置的旋转位置编码c. 对每个
token
的Query
和Key
向量,按照其元素 两两一组 应用旋转变换d. 对 Query 和 Key 向量计算内积,生成自注意力机制中的注意力权重
1.4、优化器
- 使用 AdamW 优化器进行训练,超参数设置为:$\beta_{1}=0.9,\beta_{2}=0.95$
- 使用余弦学习率调度,使得最终学习率等于最大学习率的 10%,权重衰减为 0.1,梯度裁剪为 1.0
- 使用 2000 个预热步长(warmup steps),并根据模型的大小调整学习率和批量大小
1.5、高效实现
为了提高模型的训练速度,LLaMA 做了如下优化:
- 使用高效的因果多头自注意力(causal multi-head attention)实现来减少内存使用和运行时间
- 使用了 Flash Attention 的反向传播,以此避免存储注意力权重,并且不计算由于语言建模任务的因果特性而被屏蔽的键/查询分数
- 通过检查点技术(checkpointing)减少了在反向传播期间重新计算的激活量,即保存了计算开销较大的激活(如线性层的输出)。该方法通过手动实现 Transformer 层的反向传播函数来实现,而不是依赖 PyTorch 的自动梯度计算(autograd)
- 通过使用模型并行和序列并行来减少模型的内存使用
- 尽可能地将激活计算与 GPU 间的网络通信(all_reduce 操作)重叠,以进一步提高效率
Flash Attention
传统自注意力的计算复杂度为 $O(n^2 \cdot d)$,且自注意力的常规实现会导致频繁的内存读取和写入,反向传播需要存储大量中间结果。
0、标准注意力实现:
给定输入序列 $Q$, $K$, $V \in \mathbb{R}^{N \times d}$,其中 $N$ 是序列长度,$d$ 是头维度,注意力输出 $O \in \mathbb{R}^{N \times d}$:
\[S = QK^T \in \mathbb{R}^{N \times N}, \quad P = \text{softmax}(S) \in \mathbb{R}^{N \times N}, \quad O = PV \in \mathbb{R}^{N \times d},\]其中 softmax 操作是按行(row-wise)进行的(矩阵每一行代表了不同位置之间的相似度)。
标准注意力实现算法:
矩阵 $Q $, $K $, $V \in \mathbb{R}^{N \times d} $ 存储在 HBM(高速内存)中:
- 按块从 HBM 加载 $Q$ 和 $K$,计算 $S = QK^T$,然后将 $S$ 写回 HBM
- 从 HBM 读取 $S$,计算 $P = \text{softmax}(S)$,然后将 $P$ 写回 HBM
- 按块从 HBM 加载 $P$ 和 $V$,计算 $O = PV$,然后将 $O$ 写回 HBM
- 返回 $O$
标准的注意力机制实现会将矩阵 $S$ 和 $P$ 物化到高带宽内存(HBM)中,这需要 $O(N^2)$ 的内存。通常序列长度 $N \gg d$(如 GPT-2 中,$N = 1024$,$d = 64$)。 这种问题在注意力矩阵上应用其他逐元素操作(如对 $S$ 进行掩码处理或对 $P$ 应用 Dropout)时会更加严重。
FlashAttention 作为 IO 感知的注意力算法,以解决 Transformer 在长序列上的计算和内存问题。给定输入 $Q$, $K$, $V \in \mathbb{R}^{N \times d}$ 存储在高带宽内存(HBM)中,旨在计算注意力输出 $O \in \mathbb{R}^{N \times d}$ 并将其写回 HBM,目标将对 HBM 的访问量降低到次二次复杂度(即低于 $O(N^2)$)。
FlashAttention 核心思想是将输入 $Q$、$K$、$V$ 分成多个块,将它们从慢速的 HBM 加载到快速的片上存储(SRAM)中,然后对这些块分别计算注意力输出。通过在每个块的输出上应用归一化因子并将它们累加,最终得到结果。
1、块处理(Tiling):
由于 Softmax 操作会将 $K$ 的列进行耦合,因此通过缩放来分解大规模的 softmax 操作。为了数值稳定性,FlashAttention 采用 safe softmax,向量 $x \in \mathbb{R}^B$ 的 softmax 计算方式如下:
\[m(x) := \max_i x_i\] \[f(x) := \left( e^{x_1 - m(x)}, e^{x_2 - m(x)}, \dots, e^{x_B - m(x)} \right)\] \[\ell(x) := \sum_i f(x)_i\] \[\text{softmax}(x) := \frac{f(x)}{\ell(x)}\]同理,向量 $x = [x^{(1)}; x^{(2)}] \in \mathbb{R}^{2B}$ 的 softmax 计算可分解过程如下:
\[m(x) = m([x^{(1)}; x^{(2)}]) = \max(m(x^{(1)}), m(x^{(2)}))\] \[f(x) = \left( e^{m(x^{(1)})} f(x^{(1)}), e^{m(x^{(2)})} f(x^{(2)}) \right)\] \[\ell(x) = \ell(x^{(1)}) + \ell(x^{(2)})\] \[\text{softmax}(x) = \frac{f(x)}{\ell(x)}\]因此,跟踪一些额外的统计信息(如 $m(x)$ 和 $\ell(x)$),就可以逐块计算 softmax,最后合并结果。
2、重计算(Recomputation):
反向传播通常需要矩阵 $S$ 和 $P \in \mathbb{R}^{N \times N}$,以计算相对于 $Q$、$K$、$V$ 的梯度。为了避免为反向传播过程存储 $O^{N^2}$ 中间值,通过存储输出 $O$ 和 softmax 标准化统计信息 $(m, \ell)$,可以在反向传播中从 $Q$、$K$、$V$ 的块中重新计算注意力矩阵 $S$ 和 $P$(selective gradient checkpointing)。即使有更多的 FLOPs(浮点运算),由于减少了对 HBM 的访问,重新计算加速了反向传播过程。
左图:FlashAttention 使用分块技术来防止在(相对较慢的)GPU HBM 中生成大型的 $N \times N$ 注意力矩阵(虚线框)。在外循环中(红色箭头),FlashAttention 会遍历 $K$ 和 $V$ 矩阵的块,并将它们加载到快速的片上 SRAM 中。在每个块内,FlashAttention 会遍历 $Q$ 矩阵的块(蓝色箭头),将其加载到 SRAM 中,并将注意力计算的输出写回 HBM。
右图:在 GPT-2 上,FlashAttention 相较于 PyTorch 实现的注意力计算获得了加速。FlashAttention 不会将大型的 $N \times N$ 注意力矩阵读写到 HBM,从而在注意力计算上实现了 7.6 倍的加速。
2、LLaMA 2
LLaMA 2 发布于 2023 年 7 月,是 Llama 1 的升级版。LLaMA 2 将预训练语料库的规模扩大了 40%,把模型的上下文长度翻倍,并采用了分组查询注意力机制(grouped-query attention)。LLaMA 2 系列包括 3 个参数规模版本:7B、13B 和 70B。除了基础模型,同时发布的还有针对对话使用场景优化的 Llama 2 微调版本。
2.1、预训练
LLaMA 2 基于 2 万亿词元的数据进行预训练,数据在最具事实性的数据源进行上采样以增长知识并减少幻觉。
与 LLaMA 1 的大部分预训练设置和模型架构一样,LLaMA 2 使用标准的 Transformer 架构,基于 RMSNorm 的预归一化,使用 SwiGLU 激活函数,以及旋转位置嵌入(RoPE)。与 LLaMA 1 在架构上的主要差异包括增加上下文长度(从 2048 个 token 扩展到 4096 个 token)和采用分组查询注意力机制(GQA)。
- 超参数:
- LLaMA 2 使用 AdamW 优化器进行训练( β1 = 0.9,β2 = 0.95,eps = 10⁻⁵),并使用余弦学习率调度,初始阶段进行 2000 步的预热,并将最终学习率衰减到峰值学习率的 10%。我们使用 0.1 的权重衰减和 1.0 的梯度裁剪
- 分词器:
- 使用与 LLaMA 1 相同的字节对编码(BPE)算法分词器。与 LLaMA 1 一样将所有数字拆分为单个数字,并使用字节来分解未知的 UTF-8 字符。总词汇表大小为 32k 个词汇单元。
分组查询注意力(Grouped-Query Attention)
自回归解码的标准做法是缓存序列中前几个词的键(K)和值(V)对,以加速注意力计算。然而,随着上下文窗口或批处理大小的增加,多头注意力(MHA)模型中与 KV 缓存大小相关的内存成本显著增长。对于较大的模型,当 KV 缓存大小成为瓶颈时,可以在多个头之间共享键和值的投影,而不会导致性能显著下降。在实现上,可以选择使用单一 KV 投影的多查询格式(MQA),或者具有 8 个 KV 投影的分组查询注意力(GQA)。34B 和 70B 的 LLaMA 2 模型使用 GQA。
1、增量训练(Uptraining)
从多头模型(Multi-Head Model)生成多查询模型(Multi-Query Model)的过程:1. 将多头检查点转换为多查询检查点,键(Key)和值(Value)头的投影矩阵通过均值池化(Mean Pooling)合并为单个投影矩阵;2、通过额外的预训练让模型适应新的结构
2、分组查询注意力(Grouped-Query Attention)
分组查询注意力将查询头(Query Heads)分为 $G$ 个组,每组共享一个键头(Key Head)和值头(Value Head)。
- GQA-G: 表示具有 $G$ 个组的 分组查询注意力
- GQA-1: 只有一个组,因此只有一个键头和值头,相当于 多查询注意力(Multi-Query Attention, MQA)
- GQA-H: 组数等于头数,相当于 多头注意力(Multi-Head Attention, MHA)
在将多头检查点(MHA Checkpoint)转换为 GQA 检查点(GQA Checkpoint)的过程中,通过对每组内的所有原始头进行均值池化(Mean Pooling),构造每组的键头和值头。
2.2、监督微调(Supervised Fine-Tuning,SFT)
SFT 专注于收集数千个高质量的 SFT 数据示例,通过舍弃第三方数据集中的数百万个示例,转而使用基于供应商标注工作得到的数量较少但质量更高的示例,使结果有了显著提升。大约数万条 SFT 标注就足以实现高质量的结果,在总共收集了 27540 条标注后停止了 SFT 标注工作。
此外,不同的标注平台和供应商会导致下游模型性能出现显著差异,这凸显了数据检查的重要性。为了验证数据质量,通过人工审查 180 个示例将人工提供的标注与模型生成的样本进行对比。从最终的 SFT 模型中采样得到的输出,往往与人工标注者手写的 SFT 数据不相上下,这表明可以重新规划优先级,将更多的标注精力投入到基于偏好的注释上以支持 RLHF。
监督微调使用余弦学习率调度,初始学习率为 $2 \times 10^{-5}$,权重衰减为 0.1,批量大小为 64,序列长度为 4096 个 token。
在微调过程中,每个样本由一个提示和一个答案组成。为了确保模型的序列长度正确填充,将训练集中的所有提示和答案连接起来,并使用一个特殊标记来分隔提示和答案段。微调采用自回归目标,并将来自用户提示的标记的损失归零,因此只对答案 token 进行反向传播。最后,模型进行 2 个 epoch 的微调。
2.3、人类反馈强化学习(RLHF)
2.3.1、人类偏好数据收集
方案采用二元比较协议以最大化收集到的提示多样性。标注过程要求标注者首先编写一个提示,然后根据提供的标准在两个采样的模型响应之间进行选择。为了最大化多样性,给定提示的两个响应分别从两个不同的模型变体中采样,并且调整温度超参数。此外,除了让参与者做出强制选择外,还要求标注者标注他们选择的响应相对于替代响应的偏好程度:“显著更好”,“更好”,“稍微更好”,“几乎没有区别”/“不确定”。
偏好标注收集主要关注有用性和安全性,例如,“提供制造炸弹的详细说明” 可能被认为是有用的,但是是不安全的。
最终,依据特定准则收集了一个超 100 万条二元比较的大型数据集,与现有的开源数据集相比,Meta 奖励建模数据平均对话轮次更多,长度也更长。
2.3.2、奖励建模
奖励模型采用模型响应及其对应的提示(包括前几轮对话的上下文)作为输入,并输出一个标量分数,以指示模型生成的质量(例如,有用性和安全性)。利用这样的响应分数作为奖励,可以在强化学习与人类反馈(RLHF)期间优化 LLaMA 2-CHAT。
有用性和安全性有时会相互权衡,单一的奖励模型很难在两者上都表现良好。因此训练了两个独立的奖励模型,一个针对有用性(Helpfulness RM),另一个针对安全性(Safety RM)。
从预训练的聊天模型初始化奖励模型以确保两个模型都从预训练中获取的知识中受益,以此避免两个模型出现信息不匹配的情况,否则可能会导致对幻觉内容的偏好。奖励模型的架构和超参数与预训练语言模型完全相同,只是将用于预测下一个词元的分类头替换为用于输出标量奖励的回归头。
为了训练奖励模型,将收集的成对人类偏好数据转换为二元排名标签格式(即选择和拒绝),并强制选择的响应得分高于其对应的拒绝响应,二元排名损失:
\[L_{\text{ranking}} = -\log(\sigma(r_\theta(x, y_c) - r_\theta(x, y_r)))\]其中,$r_\theta(x, y)$ 是模型权重 $\theta$ 下,给定提示 $x$ 和完成 $y$ 的标量评分输出。 $y_c$ 是标注者选择的偏好响应,$y_r$ 是被拒绝的对应响应。
考虑到偏好评分被分解为四个等级(例如,“显著更好”),在损失函数中进一步添加一个边际组件:
\[L_{\text{ranking}} = -\log(\sigma(r_\theta(x, y_c) - r_\theta(x, y_r) - m(r)))\]其中,边际 $m(r)$ 是偏好评分的离散函数。通常,对于响应差异较大的对,使用较大的边际;对于差异较小的对,使用较小的边际。边际组件可以提高有用性奖励模型的准确性,尤其是在两个响应更容易区分的样本上。
奖励模型在训练数据上进行 1 个 epoch 的训练,训练时间过长可能会导致过拟合。优化器的参数与基础模型相同,学习率的设置如下:
- 对于具有 70B 参数的 Llama 2-Chat,最大学习率为 5 × 10⁻⁶
- 对于其他模型,最大学习率为 1 × 10⁻⁵
学习率按照余弦学习率调度逐步下降,最低降至最大学习率的 10%。同时,训练的预热阶段(warm-up)为总步数的 3%,但至少为 5 步。训练有效批量大小固定为 512 对(pairs),即每批 1024 行数据。
在每一批用于奖励建模的人工偏好标注数据中,保留 1000 个样本作为测试集来评估模型。从缩放趋势上看,较大的模型在相似的数据量下可以获得更高性能。
2.3.3、迭代微调
LLaMA 2 探索了两种主要的算法进行 RLHF 微调:
- 近端策略优化(Proximal Policy Optimization, PPO): RLHF 文献中的标准算法
- 拒绝采样微调(Rejection Sampling fine-tuning): 通过从模型中采样 $K$ 个输出,并根据奖励函数选择最佳候选输出。重排序策略将奖励视为一种能量函数,在此基础上,利用选出的最佳输出对模型进行梯度更新。对于每个提示(Prompt),选择获得最高奖励分数的样本作为新的黄金标准(gold standard)。最后在新的排名样本集合上微调模型,从而强化奖励信号
两种强化学习算法的主要差异在于:
- 广度(Breadth): 在拒绝采样中,模型为每个提示探索 $K$ 个样本;而在 PPO 中,每个提示只生成一个样本
- 深度(Depth): 在 PPO 训练中,第 $t$ 步的样本由第 $t-1$ 步更新后的模型策略生成,并基于上一步的梯度更新。在拒绝采样微调中,在模型的初始策略下采样所有输出,收集新数据集后,按照 SFT 的方式进行微调。由于采用了迭代的模型更新方式,两种 RL 算法之间的根本差异并不显著
LLaMA 2 将两种方法结合起来:先在拒绝采样得到的检查点上应用 PPO,再在此基础上重新采样。
1、拒绝采样
仅在最大的 70B Llama 2-Chat 模型上执行拒绝采样。所有较小的模型则在从大模型拒绝采样得到的数据上进行微调,从而将大模型的能力蒸馏到小模型中。
在每个迭代阶段,最新模型为每个提示(prompt)采样 $K$ 个回答,并使用最佳奖励模型对每个样本进行评分,然后为每个提示选择得分最高的回答。
早期实验中,若策略仅限于从前一轮采样集合中选择答案,则会在某些能力上退化,后续迭代整合了来自所有前期迭代的表现最佳样本。
2、PPO
奖励模型被用作真实奖励函数(即人类偏好)的估计值,预训练语言模型作为需要优化的策略。目标函数:
\[\underset{\pi}{\text{arg max}} \, \mathbb{E}_{p \sim D, g \sim \pi} [R(g \mid p)]\]从数据集 $D$ 中采样提示 $p$,从策略 $\pi$ 中生成 $g$,并使用 PPO 算法及其损失函数来逐步优化策略。
在优化过程中,使用的最终奖励函数如下:
\[R(g \mid p) = \tilde{R}_c(g \mid p) - \beta D_{\text{KL}}(\pi_\theta(g \mid p) \| \pi_0(g \mid p))\]其中包含一项偏离初始策略 $\pi_0$ 的惩罚项,此约束有助于训练的稳定性,并减少 “奖励作弊”(模型在奖励模型上表现良好但在人类评价中得分低)的现象。
将 $R_c$ 定义为安全性奖励模型 ($R_s$) 和帮助性奖励模型 ($R_h$) 的分段组合:
\[R_c(g \mid p) = \begin{cases} R_s(g \mid p) & \text{if IS_SAFETY}(p) \text{ or } R_s(g \mid p) < 0.15 \\ R_h(g \mid p) & \text{otherwise} \end{cases}\]为了提高稳定性和平衡 KL 惩罚项 $\beta$,对线性得分进行了 “白化” 处理,通过反向应用 sigmoid 函数的 logit 函数实现:
\[\tilde{R}_c(g \mid p) = \text{WHITEN}(\text{LOGIT}(R_c(g \mid p)))\]- 模型优化器为 AdamW,其中 $\beta_1 = 0.9, \beta_2 = 0.95, \epsilon = 10^{-5}$
- 采用固定为 $10^{-6}$ 的学习率,0.1 的权重衰减,以及 1.0 的梯度裁剪
- 每次 PPO 迭代中,采用 512 的 batch size,0.2 的 PPO 剪裁阈值,mini-batch 大小为 64,每个 mini-batch 进行一次梯度更新
- 7B 和 13B 模型的 KL 惩罚项为 $\beta = 0.01$,34B 和 70B 模型的 KL 惩罚项为 $\beta = 0.005$
所有模型进行 200 到 400 次迭代训练,并通过对留出的提示(Prompt)进行评估来决定是否提前终止训练。70B 模型 PPO 的每次迭代平均耗时约 330 秒。为加速大批量训练,采用了完全分片数据并行(Fully Sharded Data Parallel,FSDP),在进行 $O(1)$ 前向或后向传播时效果显著,但在生成阶段,即使使用较大的批量和 KV 缓存,仍会导致约 20 倍的速度下降。通过在生成前将模型权重整合到每个节点,并在生成后释放内存,然后继续训练循环的其余部分,从而缓解了这一问题。
2.4、多轮一致性的系统消息
在对话场景中,某些指令应适用于整个对话过程,例如要求回答简洁,或 “扮演” 某位公众人物。当向 Llama 2-Chat 提供这样的指令时,模型的后续回复应始终遵守这些约束。然而,初始的 RLHF 模型在多轮对话后往往会忘记最初的指令。
为了应对这些局限性,LLaMA 2 提出了 Ghost Attention(GAtt) 方法。这是一种受到 Context Distillation(上下文蒸馏) 启发的简单技术,通过 “调整微调数据” 帮助注意力机制在多轮对话中保持一致性。
假设一个两人对话的数据集,该数据集由一系列消息组成:$[u_1, a_1, \dots, u_n, a_n]$。其中,$u_n$ 和 $a_n$ 分别表示第 $n$ 轮对话中用户和助手的消息。接下来,定义一个应贯穿整个对话的指令 $inst$,例如 “act as” (扮演某人),将这个指令合成到对话中所有的用户消息上。
然后,使用最新的 RLHF 模型对这些合成数据进行采样(对多轮对话中每个用户的消息生成多次的结果)。此时可以得到一组上下文对话和用于微调模型的样本,这一过程类似于拒绝采样(使用 RM 选择最好的结果)。
针对合成的数据集去训练模型时(SFT),可以仅在第一轮对话中添加指令,其余轮次不添加。但会导致训练时系统消息(即助手的中间回复)和采样结果之间出现不匹配问题。为了修复这一问题,简单地将先前所有对话轮次(包括助手消息)的所有词元损失设置为 0,避免它们影响训练。
3、LLaMA 3
LLaMA 3 发布于 2024 年 4 月,参数量分别有 80 亿、700 亿 和 4050 亿,模型基于超过 15T 的数据进行训练,数量是 LLaMa 2 的 7 倍,代码量多 4 倍,并支持 8K 上下文长度,是 Llama 2 容量的 2 倍。Llama 3.1 的上下文窗口大小达到了 148K。
3.1、预训练
3.1.1、网页数据整理
预训练数据大部分来自于网页数据,清洗过程如下:
- PII 和安全过滤
- 旨在从网站中去除不安全内容或个人身份信息(PII)数据、有害域名,以及已知包含成人内容的域名
- 文本提取与清理
- 构建解析器提取 HTML 内容,去除模板内容并保留有效内容的精准度,保留数学和代码内容。此外,删除了所有的 Markdown 标记
- 去重
- URL 级去重
- 保留与每个 URL 对应页面的最新版本
- 文档级去重
- 整个数据集上执行了全局 MinHash 去重,以去除近重复文档
- 行级去重
- 执行类似 ccNet的行级去重。移除在每 3000 万份文档的桶中出现超过 6 次的行
- URL 级去重
- 启发式过滤
- 使用重复 n-gram 覆盖率去除由重复内容(如日志或错误信息)组成的行(这些行可能非常长且唯一,因此无法通过行去重来过滤)
- 使用 “脏词” 计数过滤掉未包含在域名屏蔽列表中的成人网站
- 使用基于 token 分布的 Kullback-Leibler 散度,与训练语料库分布相比,筛选出那些包含过多异常词元的文档
- 基于模型的质量过滤
- 使用快速分类器(如 fasttext)识别给定文本是否会被维基百科引用
- 使用基于 DistilRoberta 的分类器为每个文档生成质量分数(基于 Llama 2 的预测结果进行训练)
- 代码与推理数据
- 构建基于 DistilRoberta 的代码分类器和推理分类器(在由 Llama 2 标注的网页数据上进行训练),通过提示调优(prompt tuning)以针对包含数学推导、STEM 领域推理以及与自然语言交织的代码的网页
- 多语言数据
- 使用基于 fasttext 的语言识别模型,将文档分类为 176 种语言
- 在每种语言的数据中执行文档级和行级的去重
- 应用语言特定的启发式方法和基于模型的过滤器,去除低质量文档
此外,使用基于 LLaMA 2 的多语言分类器对多语言文档进行质量排序,以确保高质量内容优先。
3.1.2、数据混合
通过知识分类和缩放定律实验,精确确定不同数据源在预训练数据中的比例:
- 知识分类:通过分类器对网页数据中包含的信息类型进行分类,使用该分类器对在网络上出现频率过高的数据类别(例如艺术和娱乐类)进行下采样。
- 数据组合的缩放定律:通过缩放定律实验确定最佳的数据组合。即,在一种数据组合上训练几个小型模型,然后利用这些结果预测大型模型在该数据组合上的性能。针对不同的数据组合多次重复这个过程,以选出新的数据组合候选方案。随后,在这个候选数据组合上训练一个更大的模型,并在几个关键基准测试上评估该模型的性能
- 数据组合总结:最终的数据组合中有大约 50% 为通用知识词元、25% 为数学和推理相关的词元、17% 为代码相关的词元、8% 为多语言相关的词元
3.1.3、退火数据
对少量高质量代码和数学数据进行退火能够提升预训练模型在关键基准上的表现。具体做法是通过使用一个数据混合来进行退火,其中会对特定领域中的高质量数据进行上采样。在大型模型(如 LLaMA 3 405B)上,退火的效果较为有限,因为这些模型已经具备了强大的推理能力和良好的上下文学习表现。
3.2、模型架构
LLaMA 3 与 LLaMA 1 和 LLaMA 2 基本相同,采用标准的 Transformer 架构,性能提升主要得益于数据质量与多样性的改善以及训练规模的加大。
与 LLaMA 2 相比,LLaMA 3 做了以下小幅修改:
- 使用了分组查询注意力(GQA)与 8 个键值头,以提高推理速度并在解码过程中减少键值缓存的大小
- 使用了防止同一序列内不同文档之间产生自注意力的注意力编码。这一改变在标准预训练中影响有限,但在对极长序列进行持续预训练时显得尤为重要
- 使用包含 128K 词元的词表。词元词表将来自 tiktoken3 分词器的 100K 词元与额外的 28K 词元相结合,以更好地支持非英语语言。与 Llama 2 分词器相比,新分词器将英语数据样本的压缩率从每个词元 3.17 个字符提高到 3.94 个字符,模型在相同的训练计算量下能够 “读取” 更多文本。此外,从特定非英语语言中添加 28K 词元,既提高了压缩率,又提升了下游任务的性能,且对英语分词没有影响
- 将旋转位置嵌入(RoPE)的基础频率超参数提高到 500,000,使得模型能够更好地支持更长的上下文
LLaMA 3 405B 采用具有 126 层的架构,词元表示维度为 16,384,有 128 个注意力头。根据数据上的缩放定律,对于 $3.8 \times 10^{25}$ 次浮点运算(FLOPs)的训练预算而言,这样的架构使得模型规模在计算上近乎达到最优。
3.2.1、缩放定律
通过制定缩放定律以便在给定预训练计算资源预算的情况下,确定旗舰模型的最优模型规模。预测旗舰模型在下游基准任务上的性能是一个重大挑战:
- 现有的缩放定律通常仅预测下一个词元的预测损失,而非特定基准任务的性能
- 缩放定律可能存在误差且不可靠,因为它们是基于使用较小计算资源预算进行的预训练得出的
采用两阶段方法来应对挑战:
- 首先建立计算最优模型在下游任务上的负对数似然与训练所需的浮点运算次数(FLOPs)之间的相关性
- 利用缩放定律模型以及使用更高计算 FLOPs 训练的旧模型,将下游任务的负对数似然与任务准确率相关联(在这一步中借助了 Llama 2 系列模型)
这些实验产生了下图左图中的 IsoFLOPs 曲线。使用二次多项式拟合测得的损失值,并确定每个抛物线的最小值。最终将抛物线的最小值称为在相应预训练计算预算下的计算最优模型。通过这种方式确定的计算最优模型,来预测特定计算资源预算下的最优训练词元数量。假设计算资源预算 $C$ 与最优训练词元数量 $N^*(C)$ 之间存在幂律关系:
\[N^*(C) = AC^{\alpha}\]使用左图中的数据来拟合 $A$ 和 $\alpha$,相应的拟合结果如右图所示。将由此得到的缩放定律外推到 $3.8×10^{25}$ 次浮点运算,表明应使用 $16.55T$ 个词元训练一个具有 $4020$ 亿参数的模型。
3.3、基础设施、扩展和效率
3.3.1、训练基础设施
- 计算:
- LLaMA 3 405B 模型在多达 16K 块 H100 GPU 上进行训练,每块 GPU 的热设计功耗为 700 瓦,配备 80GB 的 HBM3 显存。每台服务器配备 8 块 GPU 和 2 个 CPU,8 块 GPU 通过 NVLink 互联
- 存储:
- 由 7500 台配备固态硬盘的服务器提供 240PB 的存储容量,支持 2TB/s 的可持续吞吐量和 7TB/s 的峰值吞吐量。一个主要挑战是应对高度突发的检查点写入操作,这种操作会在短时间内使存储架构饱和。检查点操作会保存每块 GPU 的模型状态,每块 GPU 的保存数据量从 1MB 到 4GB 不等,用于恢复和调试。目标是在检查点操作期间尽量减少 GPU 的暂停时间,并增加检查点频率,以减少恢复后损失的工作量
- 网络:
- LLaMA 3 405B 使用基于 Arista 7800 和 Minipack2 开放计算项目 4 OCP 机架交换机的融合以太网远程直接内存访问(RDMA over Converged Ethernet,RoCE)架构。Llama 3 系列中的较小模型则使用英伟达 Quantum2 InfiniBand 架构进行训练。RoCE 和 InfiniBand 集群均利用 GPU 之间 400Gbps 的互联带宽
- 网络拓扑: 基于 RoCE 的 AI 集群由 24000 块 GPU 通过三层 Clos 网络连接而成)。在底层,每个机架容纳 16 块 GPU,分布在两台服务器中,并通过单个 Minipack2 机架顶(ToR)交换机连接。在中间层,192 个这样的机架通过集群交换机连接,形成一个拥有 3072 块 GPU 且具备完全二分带宽的计算单元(pod),确保无带宽超额预订。在顶层,同一数据中心大楼内的 8 个这样的计算单元通过聚合交换机连接,形成一个拥有 24000 块 GPU 的集群
- 负载均衡: 集合通信库在两块 GPU 之间创建 16 条网络流,而非仅一条,从而减少每条流的流量,并为负载均衡提供更多的流。E-ECMP 协议通过对数据包 RoCE 头部附加字段进行哈希处理,有效地在不同网络路径上平衡这 16 条流
- 拥塞控制: 在骨干网中使用深度缓存交换机以应对因集合通信模式导致的瞬时拥塞和缓冲,以限制由慢速服务器引起的持续拥塞和网络背压的影响。通过 E-ECMP 实现的更好的负载均衡显著降低了拥塞的可能性。无需使用诸如 DCQCN 等传统拥塞控制方法
- LLaMA 3 405B 使用基于 Arista 7800 和 Minipack2 开放计算项目 4 OCP 机架交换机的融合以太网远程直接内存访问(RDMA over Converged Ethernet,RoCE)架构。Llama 3 系列中的较小模型则使用英伟达 Quantum2 InfiniBand 架构进行训练。RoCE 和 InfiniBand 集群均利用 GPU 之间 400Gbps 的互联带宽
3.3.2、模型扩展的并行性
为了扩展最大模型的训练,使用四种不同类型的并行性方法的组合来分割模型。将计算分布到多个 GPU 上,并确保每个 GPU 的模型参数、优化器状态、梯度和激活适合其高带宽内存(HBM)。如图所示,结合了张量并行性(TP)、流水线并行性(PP)、上下文并行性(CP)和数据并行性(DP)。
- 张量并行性: 将单个权重张量分割成多个块,分布在不同的设备上
- 流水线并行: 按层将模型垂直划分为多个阶段,以便不同设备可以并行处理完整模型流水线的不同阶段
- 上下文并行性: 将输入上下文分成多个片段,减少对非常长序列长度输入的内存瓶颈
- 完全分片的数据并行性(FSDP): 在实现数据并行性的同时对模型、优化器和梯度进行分片,在多个GPU上并行处理数据,并在每个训练步骤后进行同步。Llama 3 的优化器状态和梯度使用 FSDP 进行分片,但对于模型分片,在前向计算后不进行重新分片,以避免在反向传播过程中出现额外的全聚集通信
1、流水线并行性改进
其中流水线并行性存在以下几个挑战:
- 批量大小限制: 每个 GPU 支持的批量大小要求能被流水线阶段的数量整除。以上图为例,流水线并行的深度优先调度需要 $N=PP=4$,而广度优先调度需要 $N=M$,其中 $M$ 是微批次的总数,$N$ 是同一阶段的前向或后向的连续微批次的数量。然而,预训练通常需要灵活调整批量大小
- 内存不平衡: 流水线并行导致资源消耗不平衡。由于嵌入和预热微批次,第一阶段消耗更多内存
- 计算不平衡: 在模型的最后一层之后,需要计算输出和损失,该阶段成为执行延迟的瓶颈
为了解决这些问题,修改了流水线调度。如下图所示,允许灵活设置 $N$(图中示例下 $N=5$),这样每个批次可以运行任意数量的微批次。如此能够:
- 当在大规模场景下存在批次大小限制时,运行比阶段数更少的微批次
- 运行更多的微批次以隐藏点对点通信,在深度优先调度(DFS)和广度优先调度(BFS)之间找到平衡点,从而实现最佳的通信和内存效率
为了平衡流水线,分别从第一阶段和最后阶段各减少一个 Transformer 层,即第一阶段的第一个模型块仅包含嵌入层,而最后阶段的最后一个模型块仅包含输出投影和损失计算。为减少流水线气泡,在一个流水线级别上采用了具有 $V$ 个流水线阶段的交错调度,总体流水线气泡率为 $\frac{PP - 1}{V * M}$。在流水线并行(PP)中采用异步点对点通信,如此显著加快训练速度,特别是在文档掩码引入额外计算不平衡的情况下。为降低内存成本,主动释放那些未来计算不再使用的张量,包括每个流水线阶段的输入和输出张量。通过这些优化,可以在不使用激活检查点的情况下,对 Llama 3 进行 8K 词元序列的预训练。
2、长序列的上下文并行性
LLaMA 3 利用 上下文并行(CP) 来提高上下文长度扩展时的内存使用效率,并支持对长度达 12.8 万的超长序列进行训练。在上下文并行中,沿着序列维度进行划分,具体而言,将输入序列划分为 $2 \times CP$ 个块,这样每个 $CP$ 级别接收两个块以实现更好的负载均衡。第 $i$ 个 $CP$ 级别会同时接收第 $i$ 个和第 $(2 \times CP - 1 - i)$ 个块。
与现有的以环形结构重叠通信和计算的上下文并行实现方式不同,LLaMA 3 上下文并行实现采用基于全收集的方法。即,首先全收集键($K$)和值($V$)张量,然后为本地查询($Q$)张量块计算注意力输出。尽管全收集通信延迟处于关键路径上,仍采用这种方法的原因为:(1)基于全收集的上下文并行注意力更容易且更灵活地支持不同类型的注意力掩码,如文档掩码;(2)由于使用了分组查询注意力(GQA),所通信的 $K$ 和 $V$ 张量比 $Q$ 张量小得多,因此暴露的全收集延迟较小。因此,注意力计算的时间复杂度比全收集高一个数量级(全因果掩码下分别为 $O (S^2)$ 和 $O (S)$,其中 $S$ 表示序列长度),使得全收集的开销可以忽略不计。
3、网络感知的并行配置
并行维度的顺序 $[TP, CP, PP, DP]$ 是针对网络通信进行优化的。最内层的并行需要最高的网络带宽和最低的延迟,因此通常限制在同一服务器内。最外层的并行可能会跨越多跳网络,应能容忍较高的网络延迟。因此,根据对网络带宽和延迟的要求,按 $[TP, CP, PP, DP]$ 的顺序安排并行维度。数据并行(DP,即全分片数据并行 FSDP)是最外层的并行,因为它可以通过异步预取分片的模型权重和规约梯度来容忍较长的网络延迟。
4、数值稳定性
为确保训练收敛,LLaMA 3 在多个微批次的反向计算过程中使用 FP32 梯度累积,并在 FSDP 的数据并行工作节点之间对 FP32 梯度进行归约-散射(reduce-scatter)。对于在前向计算中多次使用的中间张量(如视觉编码器输出),其反向梯度也以 FP32 格式进行累积。
3.4、训练方案
3.4.1、初始预训练
LLaMA 3 405B 使用 AdamW 优化器对进行预训练,峰值学习率设为 $8 \times 10^{-5}$,进行 8000 步的线性预热,通过余弦学习率在 120万 步内将学习率衰减至 $8 \times 10^{-7}$。训练初期使用较小的批次大小以提高训练稳定性,随后增大批次大小以提升效率。 初始批次大小为 400万 个词元,序列长度为 4096;在预训练 2.52亿 个词元后,将这些值翻倍,即批次大小变为 800万 个词元,序列长度变为 8192;在预训练 2.87万亿 个词元后,再次将批次大小翻倍至 1600万 个词元。训练过程中损失值波动较小,无需进行干预来纠正模型训练的发散问题。
为提高 LLaMA 3 的多语言处理能力,在预训练期间增加了非英语数据的占比;为提升模型的数学推理能力,对数学数据进行了上采样;在预训练后期,加入了更多近期的网页数据,以更新模型的知识截止时间;对于质量较低的预训练数据子集进行了下采样。
3.4.2、长上下文预训练
在预训练的最后阶段使用长序列进行训练,以使模型支持最长达 12.8万 个词元的上下文窗口。前期不使用长序列进行训练是因为自注意力层的计算量会随着序列长度的增加呈二次方增长。逐步增加支持的上下文长度,并进行预训练,直至模型成功适应增加后的上下文长度。通过两个指标来评估模型是否成功适应:(1)模型在短上下文评估中的表现是否完全恢复;(2)模型是否能完美解决该长度下的 “大海捞针” 任务。在 LLaMA 3 405B 的预训练中,分六个阶段逐步增加上下文长度,从最初的 8000 个词元的上下文窗口开始,最终达到 12.8万 个词元的上下文窗口。长上下文预训练阶段大约使用了 8000亿 个训练词元。
3.4.3、退火
在使用最后 4000 万个词元进行预训练期间,将学习率线性退火至 0,同时保持 12.8 万个词元的上下文长度。在退火阶段调整了数据组合,对高质量数据源进行上采样。最后,在退火过程中,计算模型检查点的平均值,以生成最终的预训练模型。
3.5、后训练
通过进行多轮后训练,即基于预训练检查点让模型与人类反馈对齐,来生成经过对齐的 LLaMA 3 模型。每一轮后训练都涉及有监督微调(SFT),然后在通过人类注释收集或合成生成的示例上执行直接偏好优化(DPO)。
后训练策略的核心是一个奖励模型和一个语言模型:
- 首先,基于预训练检查点,使用人工标注的偏好数据来训练一个奖励模型
- 然后,通过有监督微调(SFT)对预训练检查点进行微调,并进一步使用直接偏好优化(DPO)使检查点与人类偏好对齐
3.5.1、聊天对话格式
为了调整大型语言模型以适应人机交互,需要定义一个聊天对话协议,以便模型理解人类指令并执行对话任务。与其前身 LLaMA 2 相比,LLaMA 3 具备了新的能力(如工具使用),这需要生成多条消息并在单个对话轮次中发送到不同位置(如,用户、ipython)。为此 LLaMA 3 设计一种新的多消息聊天协议,该协议使用各种特殊头部和终止标记(头部标记用于指示对话中每条消息的来源和目的地,终止标记表示何时应该交替让人类和 AI 发言)。
3.5.2、奖励建模(Reward Modeling,RM)
在预训练检查点的基础上训练一个涵盖不同能力的奖励模型(RM)。训练目标与 LLaMA 2 相同,不过去掉了损失函数中的边距项,因为数据规模扩大后,该项带来的提升效果逐渐减弱。和 LLaMA 2 一样,在过滤掉回复相似的样本后,使用所有偏好数据进行奖励模型训练。除了标准的(选中,拒绝)回复偏好对之外,标注人员还会针对某些提示生成第三种 “编辑后回复”,也就是对偏好对中选中的回复进一步编辑完善。因此,每个偏好排序样本有两个或三个有明确排序的回复(编辑后回复 > 选中回复 > 拒绝回复)。训练时,将提示和多个回复拼接成一行,并对回复进行随机打乱。近似于将回复分别放在不同行并计算分数的标准做法,这种方法能在不降低准确率的前提下提高训练效率。
3.5.3、监督微调(Supervised Finetuning,SFT)
利用奖励模型对人工标注提示进行拒绝采样,将经过拒绝采样得到的数据与其他数据源(包括合成数据)相结合,使用标准的交叉熵损失函数,针对目标词元对预训练语言模型进行微调(同时对提示词元的损失进行掩码处理)。这一阶段称为有监督微调(SFT),尽管许多训练目标是由模型生成的。对最大规模的模型进行微调时,学习率设置为 $10^{-5}$,训练步数为 8500 到 9000 步。这些超参数设置在不同轮次和不同数据组合的情况下都能取得良好效果。
3.5.4、直接偏好优化(Direct Preference Optimization,DPO)
进一步使用直接偏好优化(DPO)对经过有监督微调(SFT)的模型进行训练,以使其与人类偏好对齐。在训练过程中,主要使用从之前对齐轮次中表现最佳的模型所收集到的最新批次的偏好数据。因此,训练数据能更好地符合每一轮中待优化的策略模型的数据分布。相比于近端策略优化算法(PPO),对于大规模模型而言,DPO 所需的计算资源更少,且表现更优。
LLaMA 3 使用 $10^{-5}$ 的学习率,并将超参数 $\beta$ 设置为 0.1。此外,对 DPO 进行了改进:
- 在 DPO 损失中屏蔽格式化词元:在损失计算中,从选中和拒绝的回复里屏蔽掉包括标题和终止词元在内的特殊格式化词元,以此来稳定 DPO 训练。若让这些词元影响损失计算,可能会导致尾部重复或突然生成终止词元。推测这可能是由于 DPO 损失的对比性质导致的,在被选中的响应和被拒绝的响应中共同存在的词元会导致学习目标冲突,因为模型需要同时增加和减少这些词元的生成概率
- 使用负对数似然损失进行正则化:在选中序列上添加一个额外的负对数似然(NLL)损失项,缩放系数设为 0.2。这有助于进一步稳定 DPO 训练,能让生成内容保持所需的格式,并防止选中回复的对数概率降低。
3.6、推理
LLaMA 3 405B 模型采用流水线并行和 FP8 量化实现推理优化。
3.6.1、流水线并行
当使用 BF16 数值表示模型参数时,LLaMA 3 405B 无法装入配备 8 块英伟达 H100 GPU 机器的 GPU 内存中。为此,使用 BF16 精度在两台机器共 16 块 GPU 上对模型推理进行并行处理。在每台机器内部,采用高带宽的 NVLink 能够使用张量并行。然而,跨节点连接的带宽较低且延迟较高,因此使用流水线并行。
在使用流水线并行进行训练时,流水线气泡是一个主要的效率问题,但在推理过程中不成问题,推理不涉及需要刷新流水线的反向传播过程。因此,在流水线并行中使用微批次处理来提高推理吞吐量。在相同的批次大小下,微批次处理可以提高推理吞吐量。由于微批处理带来的额外同步点也增加了延迟,但总体而言,微批处理仍然带来了更好的吞吐量-延迟权衡。
3.6.2、FP8 量化
为实现低精度推理,模型内部的大多数矩阵乘法操作应用 FP8 量化。具体而言,对模型前馈网络层中的大多数参数和激活值进行量化,这大约占推理计算时间的 50%,此外不对模型自注意力层中的参数进行量化。同时,借助动态缩放因子来提高精度,并对 CUDA 内核进行优化,以减少计算缩放因子的开销。LLaMA 3 405B 的性能对某些类型的量化较为敏感,因此进行了一些额外调整以提高模型输出质量:
- 不对首个和最后一个 Transformer 层进行量化
- 高困惑度的词元(如日期)可能会导致较大的激活值。这进而会使 FP8 中的动态缩放因子变得很高,引发不可忽视的下溢情况,从而导致解码错误。为解决这一问题,将动态缩放因子的上限设定为 1200
- 采用逐行量化方法,为参数矩阵和激活矩阵逐行计算缩放因子。这种方法比张量级别的量化方法效果更好
参考文献
- LLaMA: Open and Efficient Foundation Language Models
- LLaMA explained: KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU
- Root Mean Square Layer Normalization
- GLU Variants Improve Transformer
- ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Llama 2: Open Foundation and Fine-Tuned Chat Models
- The Llama 3 Herd of Models