当前位置: 首页 > article >正文

【Transformers实战篇1】基于Transformers的NLP解决方案

文章目录

  • 一、基础组件
  • 二、基于Transformer的NLP解决方案
  • 三、显存优化策略
    • 3.1 显存占用简单分析
    • 3.2 Transformers 显存优化
      • 3.2.1 Baseline 不进行任何优化
      • 3.2.2 Gradient Accumulation 梯度累加优化
      • 3.2.3 Gradient Checkpoints 选择性存一些前向激活值
      • 3.2.4 Adafactor Optiomizer 更换优化器
      • 3.2.5 Freeze Model 冻结模型层
      • 3.2.6 Data Length 降低最大长度
      • 3.2.7 更多参数高效微调


本文为 https://space.bilibili.com/21060026/channel/collectiondetail?sid=1357748的视频学习笔记

项目地址为:https://github.com/zyds/transformers-code


一、基础组件

  • Pipeline: 流水线,用于模型推理,封装了完整的推理逻辑,包括数据预处理、模型预测以及后处理
  • Tokenizer: 分词器,用于数据预处理,将原始文本输入转化为模型的输入,包括input_ids、attention_mask等
  • Model: 模型,用于加载、创建、保存模型,对pytorch中的模型进行了封装,同时更好的支持预训练模型
  • Datasets: 数据集,用于数据集加载与预处理,支持加载在线与本地的数据集,提供了数据集层面的处理方法
  • Evaluate: 评估函数,用于对模型的结果进行评估,支持多种任务的评估函数
  • Trainer: 训练器,用于模型训练、评估,支持丰富的配置选项,快速启动模型训练流程

二、基于Transformer的NLP解决方案

  • step1 导入相关包
  • step2 加载数据集 Datasets
  • step3 数据集划分 Datasets
  • step4 数据集预处理 Tokenizer +Datasets
  • step5 创建模型 Model
  • step6 设置评估函数 Evaluate
  • step7 配置训练参数 TrainingArguments
  • step8 创建训练器 Trainer + Data Collator
  • step9 模型训练、评估、预测(数据集) Trainer
  • step10 模型预测(单条) Pipeline

三、显存优化策略

当机器显存不够时,可以选择牺牲模型准确率、牺牲训练时间等方式来换空间

3.1 显存占用简单分析

  • 模型权重: 4 B y t e s × 模型参数量 4 Bytes × 模型参数量 4Bytes×模型参数量
  • 优化器状态: 8 B y t e s × 模型参数量 8 Bytes × 模型参数量 8Bytes×模型参数量 (对常用的Adamw优化器)
  • 梯度: 4 B y t e s × 模型参数量 4 Bytes × 模型参数量 4Bytes×模型参数量
  • 前向激活值: 取决于序列长度、隐层维度、Batch大小等多个因素

3.2 Transformers 显存优化

先贴一张原博主的显存优化数据:
在这里插入图片描述

3.2.1 Baseline 不进行任何优化

BatchSize设置为32, MaxLength设置为128

3.2.2 Gradient Accumulation 梯度累加优化

  • 优化对象 :前向激活值。当累计跑了32个step时,才会计算梯度
  • 参数设置:在 TrainingArguments里设置:per_device_train_batch_size=1gradient_accumulation_steps=32

3.2.3 Gradient Checkpoints 选择性存一些前向激活值

  • 优化对象 :前向激活值。选择性的保存一些前向激活值,然后在反向传播时,再重新计算。
  • 参数设置:在 TrainingArguments里设置:gradient_checkpointing=True

3.2.4 Adafactor Optiomizer 更换优化器

  • 优化对象 :优化器状态。理论比AdamW占用显存小
  • 参数设置:在 TrainingArguments里设置:optim="adafactor"

3.2.5 Freeze Model 冻结模型层

  • 优化对象 :前向激活值和梯度。 通过冻结一些层
  • 参数设置:如将Bert中的参数进行冻结。遍历模型的参数,设置param.requires_grad = False
from transformers import DataCollatorWithPadding

# *** 参数冻结 *** 
for name, param in model.bert.named_parameters():
    param.requires_grad = False

trainer = Trainer(model=model, 
                  args=train_args, 
                  tokenizer=tokenizer,
                  train_dataset=tokenized_datasets["train"], 
                  eval_dataset=tokenized_datasets["test"], 
                  data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                  compute_metrics=eval_metric)

3.2.6 Data Length 降低最大长度

  • 优化对象:前向激活值。会极大影响模型预测的准确率
  • 参数设置:在tokenizer数据集预处理处进行操作,设置max_length=32
import torch

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-large")

def process_function(examples):
    tokenized_examples = tokenizer(examples["review"], max_length=32, truncation=True, padding="max_length")
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples

tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)

3.2.7 更多参数高效微调

  • Lora
  • cpu offload
  • flash attention等

http://www.kler.cn/news/327241.html

相关文章:

  • 公网IP和内网IP比较
  • 数据结构之手搓顺序表(顺序表的增删查改)
  • plt等高线图的绘制
  • 智能家居技术的前景和现状
  • LeetCode讲解篇之15. 三数之和
  • Frp服务部署
  • 【Qt】Qt安装(2024-10,QT6.7.3,Windows,Qt Creator 、Visual Studio、Pycharm 示例)
  • string为什么存储在堆里
  • EP42 公告详情页
  • Mac制作Linux操作系统启动盘
  • 蜘蛛爬虫的ip来自机房,用户的爬虫来自于哪里
  • 日常工作第10天:
  • web笔记
  • uni-app ios 初次进入网络没有加载 导致出现异常
  • 计算机毕业设计 基于深度学习的短视频内容理解与推荐系统的设计与实现 Python+Django+Vue 前后端分离 附源码 讲解 文档
  • nacos client 本地缓存问题
  • 信息安全数学基础(23)一般二次同余式
  • 正则表达式使用指南(内容详细,通俗易懂)
  • YOLOv8改进 - 注意力篇 - 引入SCAM注意力机制
  • 【2025】基于Spring Boot的智慧农业小程序(源码+文档+调试+答疑)
  • plt绘画三维曲面
  • Android OTA升级
  • excel快速入门(二)
  • Redis缓存技术 基础第二篇(Redis的Java客户端)
  • Ingress Gateway 它负责处理进入集群的 HTTP 和 TCP 流量
  • 七星创客:重塑商业模式认知
  • 在 Linux 中,要让某一个线程或进程排他性地独占一个 CPU
  • AI芯片WT2605C赋能厨房家电,在线对话操控,引领智能烹饪新体验:尽享高效便捷生活
  • Linux:文件描述符介绍
  • 【SpringBoot详细教程】-08-MybatisPlus详细教程以及SpringBoot整合Mybatis-plus【持续更新】