语言模型的知识蒸馏

chaos 2025-02-23 日 22:14
+-------------------+         +-------------------+
|                   |         |                   |
|   Teacher Model   |         |   Student Model   |
|(Large, Cumbersome)|         |  (Small, Simple)  |
|                   |         |                   |
+--------+----------+         +---------+---------+
         |                             |
         |                             |
         |                             |
         v                             v
+-------------------+         +-------------------+      +-------------------+
|                   |         |                   |      |                   |
|   Soft Targets    |         |      logits       |      |   Hard Targets    |
|  (Probabilities)  |  ----   |      Softmax      | ---- |  (One-hot Labels) |
|                   |         |                   |      |                   |
+--------+----------+         +---------+---------+      +---------+---------+
         |                             |
         |                             |
         |                             |
         v                             v
+-----------------------------------------------+
|                                               |
|                Combined Loss                  |
|  L = (1 - α) * L_soft + α * L_hard            |
|                                               |
+-----------------------------------------------+

梯度缩放的问题

蒸馏学习中,总损失为硬目标和软目标的加权和:

\[ \mathcal{L} = \alpha \mathcal{L}_{\text{hard}} + (1-\alpha) \mathcal{L}_{\text{soft}}. \]

其中:

  • 硬目标为真实标签的 one-hot 编码 \( y = [y_1, y_2, ..., y_C] \),其中 \( y_k = 1 \)(正确类别),其他 \( y_i = 0 \)。
  • 软目标为教师模型输出的概率分布 \( p = [p_1, p_2, ..., p_C] \),通过带温度 \( T \) 的 softmax 计算得到:

    \[ p_i = \frac{\exp(v_i / T)}{\sum_j \exp(v_j / T)}, \]

    其中 \( v_i \) 是教师模型的 logits。

  • 学生模型的输出为 \( q = [q_1, q_2, ..., q_C] \),同样通过带温度 \( T \) 的 softmax 计算:

    \[ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}, \]

    其中 \( z_i \) 是学生模型的 logits。

    当和硬目标计算交叉熵损失的时候,T 为 1, 当要突出 T 的值时,会将 \( q_{i} \) 写成 \( q_{i}^{T} \)

硬目标损失的梯度

硬目标损失是学生模型输出 \( q \) 与真实标签 \( y \) 的交叉熵:

\[ \mathcal{L}_{\text{hard}} = -\sum_i y_i \log(q_i). \]

假设正确类别是 \( k \),那么 \( y_k = 1 \),其他都为 0 ,损失中求和就只有一项:

\[ \mathcal{L}_{\text{hard}} = - \log(q_k). \]

对 logits \( z_k \) 的梯度为:

\[ \frac{\partial \mathcal{L}_{\text{hard}}}{\partial z_k} = -\frac{1}{q_{k}} \frac{\partial \log(q_k)}{\partial z_k} = -\frac{1}{q_{k}}\frac{\frac{1}{T}\exp(z_k / T)\sum_j \exp(z_j / T) -\frac{1}{T} \exp(z_k / T)\exp(z_k / T)}{(\sum_j \exp(z_j / T))^{2}} \]

化简得到: \[ \frac{\partial \mathcal{L}_{\text{hard}}}{\partial z_k} = -\frac{1}{Tq_{k}}(q_{k}(1-q_{k})) = \frac{1}{T}(q_{k}-1) \]

对 logtis 中 \( z_{i}, (i\ne k)\) 的梯度为:

\[ \frac{\partial \mathcal{L}_{\text{hard}}}{\partial z_j} = -\frac{1}{q_{k}} \frac{\partial \log(q_k)}{\partial z_j} = -\frac{1}{q_{k}}\frac{ -\frac{1}{T} \exp(z_k / T)\exp(z_j / T)}{(\sum_j \exp(z_j / T))^{2}} \]

化简得到:

\[ \frac{\partial \mathcal{L}_{\text{hard}}}{\partial z_j} = \frac{1}{Tq_{k}}(q_{k}q_{j}) = \frac{1}{T}q_{j} = \frac{1}{T}(q_{j}-0) \]

最后写成 \( q_{j}-0 \) 是可以和 \( q_{k}-1 \) 形式统一,额外的,由于 T=1 因此,任意硬损失对任意 \( z_{i} \) 的梯度可以写成

\[ \frac{\partial \mathcal{L}_{\text{hard}}}{\partial z_i} = \frac{1}{T}(q_{j}-y_{j}) = (q_{j}^{1}-y_{j}) \]

软目标损失的梯度

软目标损失是学生模型输出 \( q \) 与老师模型输出 \( p \) 的交叉熵:

软目标损失为: \[ \mathcal{L}_{\text{soft}} = -\sum_i p_i \log(q_i). \]

其中

\[ p_i = \frac{\exp(v_i / T)}{\sum_j \exp(v_j / T)}. \]

\[ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}. \]

对 \( z_k \) 求导: \[ \frac{\partial \mathcal{L}_{\text{soft}}}{\partial z_k} = -\sum_i p_i \cdot \frac{\partial \log(q_i)}{\partial z_k}. \]

这里不像上一节那样先对 log 求导在对其中 softmax, 而是把 log(q) 展开:

\[ \log(q_i) = \frac{z_i}{T} - \log\left(\sum_j \exp(z_j / T)\right), \]

其导数更容易计算:

\[ \frac{\partial \log(q_i)}{\partial z_k} = \begin{cases} \frac{1}{T} (1 - q_k) & \text{if } i = k, \\ -\frac{1}{T} q_k & \text{if } i \neq k. \end{cases} \]

代入梯度公式: \[ \frac{\partial \mathcal{L}_{\text{soft}}}{\partial z_k} = -\frac{1}{T} \left( p_k (1 - q_k) - \sum_{i \neq k} p_i q_k \right) = \frac{1}{T} \left( \sum_{i} p_i q_k -p_k\right) = \frac{1}{T} (q_k - p_k). \]

这和硬损失的梯度形式:

\[ \frac{\partial \mathcal{L}_{\text{hard}}}{\partial z_i} = \frac{1}{T}(q_{j}-y_{j}) \]

实际是一样的,但要注意的是,硬损失梯度公式中 \( q_{j} \) 实际是 \( q_{j}^{1} \), 而软损失中 \( q_{j} \) 是 \( q_{j}^{T} \)

梯度比例的分析

现在我们看软目标梯度 \( \frac{1}{T}(q_{j}^{T}-p^{T}_{j}) \) 和硬目标梯度 \( (q_{j}^{1}-y_{j}) \) 的比例关系

这里需要借助一些近似,首先对于 x 很小的时候,有 \( \log (1+x) \approx x\), 那么当温度 T 远大于 logit \( z_{i} \) 时, \( \log(1+\frac{z_{i}}{T})\approx \frac{z_{i}}{T} \)

因此有: \[ 1+\frac{z_{i}}{T} \approx e^{\frac{z_{i}}{T}} \]

那么对于较大的 T (高温蒸馏) \[ \frac{\partial \mathcal{L}_{\text{soft}}}{\partial z_k} = \frac{1}{T} (q_{k}-p_{k}) \approx \frac{1}{T}\left(\frac{1+\frac{z_{k}}{T}}{C+\sum_{j}\frac{z_{j}}{T}}-\frac{1+\frac{v_{k}}{T}}{C+\sum_{j}\frac{v_{j}}{T}}\right) \]

这里 C 是类别数,如果两个模型的 logits 层的均值都接近 0 (均值不影响 softmax 结果),那么:

\[ \frac{\partial \mathcal{L}_{\text{soft}}}{\partial z_k} \approx \frac{1}{T^{2}C}(z_{k}-v_{k}) \]

因此可以认为该梯度正比于 \( \frac{1}{T^{2}} \) ,而硬目标的梯度大小为 \( q_k - y_k \propto 1 \) (注意这是一种类似分析算法复杂度的尺度估计方式)

因此最终软目标的损失应该定义为:

\[ \mathcal{L}_{\text{soft}} = -T^{2}\sum_i p_i \log(q_i). \]

最终损失为 \( \alpha \mathcal{L}_{\text{hard}}+(1-\alpha)\mathcal{L}_{\text{soft}} \)

logits 之间的平方损失

\[ \frac{\partial \mathcal{L}_{\text{soft}}}{\partial z_k} \approx \frac{1}{T^{2}C}(z_{k}-v_{k}) \]

公式还提供了一个信息,即当 T 较大模型的 logits 均值为 0 时,该梯度等价于 \(\frac{1}{2}(z_{k}-v_{k})^{2}\) 作为损失函数的场景,因此后者(称为 matching logit )是知识蒸馏的一个特例。

matching logit 损失的梯度的特点是,对每个独立的 logit, \( z_{k}-v_{k} \) 实际可以是任意的,它们和最终的概率在尺度上没有什么直接关联(softmax 关注的是不同 logit 的大小关系)

teacher_logit = [1,2,4,8]
student_logit = [2,4,8,16]

这意味着,对于任何分类对应的 logit, matching logit 损失的梯度的大小都是一样的。

而对于 T 较小的时候,蒸馏损失的梯度: \[ \frac{\partial \mathcal{L}_{\text{soft}}}{\partial z_k} = \frac{1}{T} (q_k - p_k). \]

如果学生模型的预测 \( q_{k} \) 和教师的答案 \( p_{k} \) 很不一样,那么梯度的尺度(绝对值)就会变大,意味着错误更容易纠正,而那些更细微的差别更不容易被更新。

因此温度 T 越大,“蒸馏”的概念也就越凸显,它意味着那些更小的差异的类别会得到更多关注。

如何选择 T 是需要经验验证的,因为更小差异的类别中实际包括噪音,但相比于硬标签, soft 标签里这些小概率的值实际是“泛化”所需要的核心信息。

在 2015 论文中第 3 节提到

When the distilled net had 300 or more units in each of its two hidden layers, all temperatures above 8 gave fairly similar results. But when this was radically reduced to 30 units per layer, temperatures in the range 2.5 to 4 worked significantly better than higher or lower temperatures.

$3 Distilling the Knowledge in a Neural Network

这从经验上说明,学生模型的容量较大的时候,可以用更高的温度去学习(蒸馏),多关注 soft target 里概率分配低的那些标签,从而获得更好的泛化信息,而当模型容量更小的时候,这些信息可能带来的噪音比例更为明显,因此用较小的温度(这时候 soft loss 部分的 \( T^{2} \) 也会减小,即训练中多关注原始分类里的知识)。

radioLinkPopups

如对本文有任何疑问,欢迎通过 github issue 邮件 metaescape at foxmail dot com 进行反馈