基于Python的自然语言处理系列(50):Soft Prompt 实现
在本篇文章中,我们将实现一个简单的 Soft Prompt 技术,该技术允许我们仅微调新增的嵌入权重,而保持预训练模型不变。Soft Prompt 的主要优势在于它的参数高效性,使得模型在特定任务上快速适应,而无需重新训练模型的所有权重。
1. Soft Prompt 概述
Soft Prompt 技术来源于论文 The Power of Scale for Parameter-Efficient Prompt Tuning。它通过在模型输入嵌入层添加可训练的软提示嵌入(soft prompt embeddings),使得我们可以仅微调这些新增嵌入,达到适应新任务的目的。这种方法不仅保留了原始模型的完整性,还大幅减少了训练所需的参数和时间成本。
2. 加载 GPT2 模型
我们首先加载预训练的 GPT2 语言模型,并检查其原始的输入嵌入。
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import os
import torch
import torch.nn as nn
# 设置代理(如果需要)
os.environ['http_proxy'