llms 文本分类的微调
文章目录
- 文本分类的微调
- 为增加性能,可以微调最后transformer层
- 微调的损失函数
- 进行微调
- spam
文本分类的微调
分类微调模型只能预测它在训练期间看到的类(例如,“垃圾邮件”或“非垃圾邮件”),而指令微调模型通常可以执行许多任务。
注意,微调往往使用监督学习。
准备数据,下载并解压缩数据集。
微调的目标是 gpt2 ,下载 gpt2 的权重并加载进模型。
model_size = CHOOSE_MODEL.split(" “)[-1].lstrip(”(“).rstrip(”)")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir=“gpt2”)
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
model.eval();
目标是替换和微调输出层
,为了实现这一点,首先冻结模型,这意味着我们使所有层都不可训练
for param in model.parameters():
param.requires_grad = False
然后,我们替换输出层 ( model.out_head ),它最初将层输入映射到 50,257 个维度(词汇表的大小)
由于我们对二元分类的模型进行了微调(预测 2 个类,“垃圾邮件”和“非垃圾邮件”),因此我们可以替换如下所示的输出层,默认情况下该层是可训练的。
torch.manual_seed(123)
num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG[“emb_dim”], out_features=num_classes)
为增加性能,可以微调最后transformer层
从技术上讲,仅训练输出层就足够了, 在实验中发现的那样,微调额外的层可以显着提高性能(一开始就忘了,loss 很差而且不收敛)
。
因此,我们还使最后一个 Transformer 模块和最后一个连接到 Transformer 层的 LayerNorm 模块可训练。
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
for param in model.final_norm.parameters():
param.requires_grad = True
微调的损失函数
在开始微调(/training)之前,我们首先必须定义要在训练期间优化的损失函数
目标是最大限度地提高模型的垃圾邮件分类准确性;但是,但分类精度不是一个可微函数
因此,将交叉熵损失最小化,作为最大化分类准确性的代理
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :]
# Logits of last output token
# 唯一的不同,只关注最后一个 token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
基于因果注意机制,第 4 个(最后一个)token 包含的信息在所有 token 中最多,因为它是唯一包含所有其他 token 信息的 token 。
进行微调
微调和预训练一样。根据损失函数来计算梯度,更新权重。
最后看看微调的效果。
text_1 = (
“You are a winner you have been specially”
" selected to receive $1000 cash or a $2000 award."
)
print(classify_review(
text_1, model, tokenizer, device, max_length=train_dataset.max_length
))
spam
text_2 = (
“Hey, just wanted to check if we’re still on”
" for dinner tonight? Let me know!"
)
print(classify_review(
text_2, model, tokenizer, device, max_length=train_dataset.max_length
))