Dataset Distillation with Attention Labels for Fine-tuning BERT
文章使用了DD更新的方式,就是先使用蒸馏数据集训练一个模型,然后计算真实数据在这个模型上的损失,更新蒸馏数据集。
文章的做法是:在训练蒸馏数据集网络时,加入了attention损失
这时候生成数据集不仅仅包含原始数据x
和y
,还包含了a
,这是attention模块的输出,作者只取了[CSL]
模块的输出。
之后使用蒸馏数据集训练模型时,不仅需要x,y的预测损失,还需要加入[cls]
的损失。