GKD
GKD(Generalized Knowledge Distillation,广义知识蒸馏)训练算法由论文 On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes 提出。该算法通过结合离线(off-policy)和在线(on-policy)学习策略,将教师模型的知识迁移到学生模型中。
损失函数
当给定输入序列 \(x\) 与输出序列 \(y\),GKD 的损失函数可以写为:
其中:
\(y_{<t} = (y_1, y_2, \ldots, y_{t-1})\):前 \(t-1\) 个 token 的序列
\(P_{\text{teacher}}(\cdot | x, y_{<t})\):教师模型在给定上下文 \(x, y_{<t}\) 时的输出概率分布
\(P_{\text{student}}(\cdot | x, y_{<t})\):学生模型在给定上下文 \(x, y_{<t}\) 时的输出概率分布
\(D(\cdot, \cdot)\):散度函数,用于度量两个概率分布之间的差异性
散度度量函数
KL 散度(Kullback-Leibler Divergence)
KL 散度是衡量两个概率分布 \(P\) 和 \(Q\) 之间差异的非对称度量:
Forward KL 与 Reverse KL
在知识蒸馏中,根据 KL 散度中两个分布的顺序不同,有两种选择:
Forward KL(前向 KL)
特性:Mode-covering
期望在教师分布下计算
学生模型倾向于覆盖教师的整个分布(包括低概率区域)
Reverse KL(反向 KL)
特性:Mode-seeking
期望在学生分布下计算
学生模型倾向于集中在教师模型的峰值区域(高概率区域)
广义 Jensen-Shannon 散度(Generalized JSD)
GKD 使用广义 JSD 作为核心度量,通过参数 \(\beta \in [0, 1]\) 在 Forward KL 和 Reverse KL 之间进行平滑插值。
对于两个概率分布 \(P\) 和 \(Q\),广义 JSD 定义为:
其中混合分布 \(M\) 定义为:
当 \(\beta = 0.5\) 时,退化为标准的对称 JSD
通过调节 \(\beta\),可以在 Mode-seeking 和 Mode-covering 之间权衡
在 GKD 中,我们令 \(P = P_{\text{teacher}}\),\(Q = P_{\text{student}}\),因此:
其中 \(M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}}\)
对极端情况(\(\beta = 0\) 或 \(\beta = 1\)),直接计算单个 KL 散度:
当 \(\beta = 0\) 时:直接定义 \(D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})\)(Forward KL,Mode-covering)
当 \(\beta = 1\) 时:直接定义 \(D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})\)(Reverse KL,Mode-seeking)
当 \(0 < \beta < 1\) 时:使用上述混合分布公式进行插值
通过调节 \(\beta\) 参数,可以在不同的散度度量之间进行插值,当 \(\beta = 0.5\) 时,散度为标准的对称 JSD。
三种训练模式
GKD训练具有三种训练模式,区别在于输出序列 \(y\) 的来源。
模式选择逻辑
训练时,每个样本按照以下优先级选择模式:
# 伪代码:模式选择逻辑
if random() < lmbda:
# Mode 1: On-Policy 学习,由学生模型采样输出序列
y = student.generate(x)
source = "student"
elif seq_kd:
# Mode 2: Sequential KD,由教师模型采样输出序列
y = teacher.generate(x)
source = "teacher"
else:
# Mode 3: 使用数据集中的输出序列
y = y_ground_truth
source = "dataset"
# 相同的损失函数
loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
Mode 1: On-Policy 学习
设置参数lambda, 以概率 \(\lambda\) 触发,使用学生模型采样 \(y \sim P_{\text{student}}(\cdot | x)\)
学生模型从自己生成的序列中学习
暴露在自己可能犯的错误中,学会自我纠正和错误恢复
对齐训练分布与推理分布
提升模型的鲁棒性和实际应用表现
适用场景:
学生模型已有一定生成能力
希望提升模型在真实推理场景下的表现
Mode 2: Sequential KD(seq_kd=True 且未触发 on-policy)
设置参数 seq_kd=True, 当未触发 on-policy 时,使用教师模型采样
数据来源:\(y \sim P_{\text{teacher}}(\cdot | x)\)
Mode 3: 离线学习(其他情况)
数据来源:\(y = y^* \sim \text{Dataset}\)
学生模型从数据集的标注序列中学习
参数设置
我们可以通过设置以下参数进行 GKD 训练:
基础参数
| 参数 | 类型 | 默认值 | 取值范围 | 说明 |
|---|---|---|---|---|
--teacher_model |
str | None | - | 教师模型路径或模型 ID |
--beta |
float | 0.5 | [0.0, 1.0] | 散度插值系数 • 0.0: Forward KL • 0.5: JSD (平衡) • 1.0: Reverse KL |
--lmbda |
float | 0.5 | [0.0, 1.0] | On-Policy 学习触发概率 • 0.0: 离线学习 • 0.5: 混合策略 • 1.0: 纯 On-Policy |
--seq_kd |
bool | False | True/False | 是否使用教师生成序列 • False: 非 on-policy 时使用数据集 • True: 非 on-policy 时使用教师生成 |
--temperature |
float | 0.9 | > 0 | 生成采样温度,控制随机性 |
--sft_alpha |
float | 0 | >= 0 | 混合一定比例的sft loss,对非student生成结果生效 |
--max_completion_length |
int | 512 | > 0 | 生成时的最大 token 数 |
Top-K KL 计算
默认情况下,GKD 使用完整词表计算 KL 散度,容易造成 OOM,这种情况下可以使用 Top-K 模式来减少显存占用和计算量。
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--gkd_logits_topk |
int | None | Top-K logits 数量 • None: 使用完整词表(默认) • 正整数: 仅使用教师模型概率最高的 K 个 token 计算 KL |
Top-K 模式原理:
在 Top-K 模式下,选取教师模型输出概率最高的 K 个 token,在这个子集上计算两个模型分布的 KL 散度。 $\( D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M}) \)$
其中 Top-K 索引来自教师模型:\(\text{Top-K} = \text{argtop}_K(P_T)\),\(\tilde{P}_T\) 和 \(\tilde{P}_S\) 是在 Top-K 子集上重新归一化的概率分布:
使用示例:
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-7B-Instruct \
--teacher_model Qwen/Qwen2.5-72B-Instruct \
--gkd_logits_topk 64 \
--dataset your_dataset \
...
注意:Top-K 模式不能与 liger kernel 同时使用(
--use_liger_kernel)。
外部教师模型 API
当设置 gkd_logits_topk 时,可以使用外部教师模型 API 服务来获取 logprobs,这样可以避免在训练进程中加载教师模型。
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--teacher_model_server |
str | None | 教师模型服务地址 如: http://localhost:8000 |
--gkd_logits_topk |
int | 必需 | 使用外部 API 时必须设置,对应 API 返回的 top_logprobs 数量 |
步骤 1:部署教师模型服务
# 使用 vllm serve 部署教师模型
CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \
--port 8000 \
--max-logprobs 64 \
--gpu-memory-utilization 0.9
步骤 2:启动 GKD 训练
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-7B \
--teacher_model_server http://localhost:8000 \
--gkd_logits_topk 64 \
--dataset your_dataset \
--lmbda 1.0 \
--beta 1.0 \
...
vLLM max_logprobs 限制:
vLLM 默认
max_logprobs=20,可通过--max-logprobs N参数调整
gkd_logits_topk不能超过服务端的max_logprobs设置
采样加速
在 GKD 训练中,涉及到两种在线采样的情况:
学生模型采样(当
lmbda > 0):以 \(\lambda\) 概率触发学生模型采样教师模型采样(当
seq_kd=True):以 \(1-\lambda\) 概率触发教师模型采样
由于采样过程会显著减慢训练速度,可参考以下两种加速方案:
方案 1:学生模型采样加速
使用 vLLM 作为推理后端来加速学生模型采样,支持两种部署模式,与 GRPO 一致,参考GRPO文档, 相关参数参考GRPO vLLM 参数
注意:vLLM 加速仅适用于学生模型的 on-policy 采样(
lmbda > 0)。教师模型的 sequential KD 采样(seq_kd=True)目前仍使用 Transformers,建议使用预采样方案。
训练脚本参考这里
使用 Teacher Server 的训练脚本参考这里
方案 2:教师模型预采样
对于教师模型采样(seq_kd=True),推荐使用 预采样 方式:先用教师模型离线生成高质量数据,再进行训练。
步骤 1:使用教师模型生成数据
export teacher_model='OpenGVLab/InternVL3-8B'
NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift infer \
--model $teacher_model \
--infer_backend vllm \
--val_dataset 'modelscope/coco_2014_caption:validation#5000' \
--vllm_gpu_memory_utilization 0.9 \
--vllm_max_model_len 8192 \
--max_new_tokens 2048 \
--write_batch_size 1000 \
--result_path teacher_generated_data.jsonl
步骤 2:使用预生成数据训练
swift rlhf \
--rlhf_type gkd \
--model OpenGVLab/InternVL3-2B-Pretrained \
--teacher_model $teacher_model \
--dataset 'teacher_generated_data.jsonl' \
--seq_kd false \
...
训练脚本参考这里
On-Policy Distillation
我们可以通过设置以下参数实现 Thinking Machine Lab blog 中的On-Policy Distillation训练。
--lmbda 1 # on-policy
--beta 1 # reverse
相关脚本可以参考这里
OPSD(On-Policy Self-Distillation)
OPSD(On-Policy Self-Distillation) 是一种单模型自蒸馏方法,无需额外的教师模型。核心思想是:同一个模型同时扮演教师和学生,教师通过接收特权信息(如参考解答)来引导学生学习。
核心机制
学生:仅看到问题,正常推理
教师:看到问题 + 参考解答(通过
teacher_prompt列提供特权信息),产出更优的概率分布训练目标:用 JSD 散度对齐学生和教师的输出分布
两种自蒸馏模式
| 模式 | 参数配置 | 教师权重 | 说明 |
|---|---|---|---|
| Dynamic(动态) | 不传 --teacher_model |
学生当前权重 | 教师随训练同步更新 |
| Fixed(固定) | --teacher_model 设为与学生相同的模型 |
初始教师权重 | 教师权重固定 |
数据格式
OPSD 需要数据集包含 teacher_prompt 列来提供教师的特权信息。可通过 --external_plugins 加载数据处理插件来构建该列。
以数学推理数据集 open-r1/OpenThoughts-114k-math 为例:
from swift.dataset import DatasetMeta, RowPreprocessor, register_dataset
class OpenThoughtsOPSDPreprocessor(RowPreprocessor):
def preprocess(self, row):
if not row.get('correct', True):
return None
problem = row.get('problem', '')
solution = row.get('solution', '')
# 教师看到问题 + 参考解答
teacher_prompt = f'{problem}\n\nReference solution:\n{solution}\n\nNow articulate your own reasoning.'
messages = [
{'role': 'system', 'content': 'Please reason step by step, and put your final answer within \\boxed{}.'},
{'role': 'user', 'content': problem},
]
return {'messages': messages, 'teacher_prompt': teacher_prompt}
register_dataset(DatasetMeta(
ms_dataset_id='open-r1/OpenThoughts-114k-math',
preprocess_func=OpenThoughtsOPSDPreprocessor(),
tags=['math', 'opsd'],
))
参数设置
OPSD 复用 GKD 的所有参数,核心区别在于 --teacher_model 的配置:
| 参数 | Dynamic 模式 | Fixed 模式 |
|---|---|---|
--teacher_model |
不设置 | 设为与 --model 相同的模型 |
参考脚本参考这里
Megatron脚本参考这里