当前位置: 首页 > article >正文

llms 文本分类的微调

文章目录

  • 文本分类的微调
  • 为增加性能,可以微调最后transformer层
  • 微调的损失函数
  • 进行微调
  • spam

文本分类的微调

分类微调模型只能预测它在训练期间看到的类(例如,“垃圾邮件”或“非垃圾邮件”),而指令微调模型通常可以执行许多任务。
注意,微调往往使用监督学习。
准备数据,下载并解压缩数据集。

微调的目标是 gpt2 ,下载 gpt2 的权重并加载进模型。
model_size = CHOOSE_MODEL.split(" “)[-1].lstrip(”(“).rstrip(”)")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir=“gpt2”)

model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
model.eval();

目标是替换和微调输出层,为了实现这一点,首先冻结模型,这意味着我们使所有层都不可训练
for param in model.parameters():
param.requires_grad = False
然后,我们替换输出层 ( model.out_head ),它最初将层输入映射到 50,257 个维度(词汇表的大小)
由于我们对二元分类的模型进行了微调(预测 2 个类,“垃圾邮件”和“非垃圾邮件”),因此我们可以替换如下所示的输出层,默认情况下该层是可训练的。
torch.manual_seed(123)

num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG[“emb_dim”], out_features=num_classes)

为增加性能,可以微调最后transformer层

从技术上讲,仅训练输出层就足够了, 在实验中发现的那样,微调额外的层可以显着提高性能(一开始就忘了,loss 很差而且不收敛)
因此,我们还使最后一个 Transformer 模块和最后一个连接到 Transformer 层的 LayerNorm 模块可训练。
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True

for param in model.final_norm.parameters():
param.requires_grad = True

微调的损失函数

在开始微调(/training)之前,我们首先必须定义要在训练期间优化的损失函数
目标是最大限度地提高模型的垃圾邮件分类准确性;但是,但分类精度不是一个可微函数
因此,将交叉熵损失最小化,作为最大化分类准确性的代理
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :]
# Logits of last output token
# 唯一的不同,只关注最后一个 token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
基于因果注意机制,第 4 个(最后一个)token 包含的信息在所有 token 中最多,因为它是唯一包含所有其他 token 信息的 token 。

进行微调

微调和预训练一样。根据损失函数来计算梯度,更新权重。
最后看看微调的效果。
text_1 = (
“You are a winner you have been specially”
" selected to receive $1000 cash or a $2000 award."
)

print(classify_review(
text_1, model, tokenizer, device, max_length=train_dataset.max_length
))

spam

text_2 = (
“Hey, just wanted to check if we’re still on”
" for dinner tonight? Let me know!"
)

print(classify_review(
text_2, model, tokenizer, device, max_length=train_dataset.max_length
))


http://www.kler.cn/a/298696.html

相关文章:

  • AAAI 2025论文分享┆一种接近全监督的无训练文档信息抽取方法:SAIL(文中附代码链接)
  • uniapp 微信小程序开发使用高德地图、腾讯地图
  • 通过无障碍服务(AccessibilityService)实现Android设备全局水印显示
  • Python爬虫教程——7个爬虫小案例(附源码)_爬虫实例
  • 未来网络技术的新征程:5G、物联网与边缘计算(10/10)
  • Pytorch | 利用VA-I-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
  • 《论多源数据集成及应用》写作框架,软考高级系统架构设计师
  • 【Android】NestedScrollView的简单用法与滚动冲突、滑动冲突
  • 聚观早报 | 红魔电竞平板新品发布;台积电8月份营收
  • LabVIEW步进电机控制方式
  • node.js入门基础
  • Learn OpenGL In Qt之着色器
  • 【C++】 Vector
  • mysql mgr 集群部署 单主模式和多主模式
  • [论文笔记] t-SNE数据可视化
  • Java笔试面试题AI答之JDBC(3)
  • framebuffer
  • Android13修改Setting实现电量低于30%的话不可执行Rest操作
  • ubuntu配置tftp、nfs
  • 【编程基础知识】Spring过滤器、拦截器、AOP区别
  • 《JavaScript 中数据类型判断、转换技巧及应用实例》
  • GitHub每日最火火火项目(9.10)
  • 最新版 | SpringBoot3如何自定义starter(面试常考)
  • Java数组的定义及遍历
  • 【局域网投屏】sunshine和moonlight投屏/屏幕共享/扩展屏
  • LabVIEW软件,如何检测连接到的设备?