Python 【大模型】之 使用千问Qwen2-VL 大模型训练LaTeX数学公式图,并进行LaTeX图识别测试
Python 【大模型】之 使用千问Qwen2-VL 大模型训练LaTeX数学公式图,并进行LaTeX图识别测试
目录
Python 【大模型】之 使用千问Qwen2-VL 大模型训练LaTeX数学公式图,并进行LaTeX图识别测试
一、简单介绍
二、千问 Qwen2-VL
三、LaTeX 公式
四、环境准备
1、环境
2、 pip 安装的一些主要 package
3、创建一个虚拟环境
五、图片数据准备,模型下载、训练、和测试
1、用于训练的图片数据下载与归档
2、Qwen2-VL 模型下载、训练、和测试
六、工程下载
附录
一、简单介绍
Python是一种跨平台的计算机程序设计语言。是一种面向对象的动态类型语言,最初被设计用于编写自动化脚本(shell),随着版本的不断更新和语言新功能的添加,越多被用于独立的、大型项目的开发。Python是一种解释型脚本语言,可以应用于以下领域: Web 和 Internet开发、科学计算和统计、人工智能、教育、桌面界面开发、软件开发、后端开发、网络爬虫。
Python 机器学习是利用 Python 编程语言中的各种工具和库来实现机器学习算法和技术的过程。Python 是一种功能强大且易于学习和使用的编程语言,因此成为了机器学习领域的首选语言之一。Python 提供了丰富的机器学习库,如Scikit-learn、TensorFlow、Keras、PyTorch等,这些库包含了许多常用的机器学习算法和深度学习框架,使得开发者能够快速实现、测试和部署各种机器学习模型。
通过 Python 进行机器学习,开发者可以利用其丰富的工具和库来处理数据、构建模型、评估模型性能,并将模型部署到实际应用中。Python 的易用性和庞大的社区支持使得机器学习在各个领域都得到了广泛的应用和发展。
二、千问 Qwen2-VL
2023 年 8 月,通义千问开源第一代视觉语言理解模型 Qwen-VL,成为开源社区最受欢迎的多模态模型之一。短短一年内,模型下载量突破 1000 万次。目前,多模态模型在手机、车端等各类视觉识别场景的落地正在加速,开发者和应用企业也格外关注 Qwen-VL 的升级迭代。
相比上代模型,Qwen2-VL 的基础性能全面提升,可以读懂不同分辨率和不同长宽比的图片,在 DocVQA、RealWorldQA、MTVQA 等基准测试创下全球领先的表现;可以理解 20 分钟以上长视频,支持基于视频的问答、对话和内容创作等应用;具备强大的视觉智能体能力,可自主操作手机和机器人,借助复杂推理和决策的能力,Qwen2-VL 可以集成到手机、机器人等设备,根据视觉环境和文字指令进行自动操作;能理解图像视频中的多语言文本,包括中文、英文,大多数欧洲语言,日语、韩语、阿拉伯语、越南语等
千问Qwen2-VL是阿里通义千问推出的新一代视觉语言模型,具备以下特点和功能:
1. 强大的视觉理解能力:
- Qwen2-VL能够识别任意分辨率的图像,无论图像的清晰度或大小如何,都能轻松识别。其独特的naive dynamic resolution支持将任意分辨率的图像映射成动态数量的视觉token,模拟人类视觉感知的自然方式。
- 该模型还能理解超过20分钟的长视频,通过在线流媒体能力,支持高质量的视频问答、对话和内容创作等应用。
2. 多语言支持:
- Qwen2-VL支持英语、中文以及包括欧洲语言、日语、韩语、阿拉伯语、越南语在内的多语言上下文理解,打破了语言障碍,为多语言环境下的应用提供了便利。
3. 视觉智能体能力:
- Qwen2-VL凭借先进的推理和决策能力,可以与手机、机器人等设备集成,实现基于视觉输入和文本指令的自主操作。
4. 模型架构:
- Qwen2-VL延续了ViT加Qwen2的串联结构,三个尺寸的模型都采用了600M规模大小的ViT,支持图像和视频统一输入。在架构上进行了升级,包括实现了对原生动态分辨率的全面支持和使用了多模态旋转位置嵌入(M-RoPE)方法,使得大规模语言模型能够同时捕捉和整合一维文本序列、二维视觉图像以及三维视频的位置信息。
5. 模型性能:
- 在多个权威测评中,Qwen2-VL创造了同等规模开源模型的最佳成绩。在mathvista、docvqa、realworldqa、mtvqa等基准测试中创下全球领先的表现,在文档理解方面优势尤其明显。与GPT-4O和Claude3.5-Sonnet等闭源模型相比,Qwen2-VL在大部分指标上都达到了最优。
6. 模型下载与推理:
- Qwen2-VL进行了开源,包含两个尺寸的模型,分别是Qwen2-VL-2B-Instruct以及Qwen2-VL-7B-Instruct,并提供了其GPTQ和AWQ的量化版本。模型可以通过ModelScope CLI进行下载,并提供了详细的安装依赖和模型推理步骤。
Qwen2-VL的发布为多模态技术的发展注入了新的活力,其在图像和视频理解方面取得了显著的突破,并具备强大的视觉智能体能力,能够与各种设备进行交互,为用户带来全新的体验。
三、LaTeX 公式
LaTeX 公式是一种使用 LaTeX 语言编写的数学公式。LaTeX 是一种排版系统,特别适合于数学和科学文档的排版,因为它能够精确地渲染复杂的数学表达式和符号。
在 LaTeX 中,公式通常被包含在数学模式中,这可以通过以下方式进入:
- 行内数学模式:使用两个美元符号
$...$
包围公式,例如$a^2 + b^2 = c^2$
。- 显示数学模式:使用两个美元符号
$$...$$
包围公式,或者使用\begin{equation}... \end{equation}
环境,例如$$a^2 + b^2 = c^2$$
或者
LaTeX 公式可以包含各种数学符号和结构,例如分数、根号、求和符号、积分符号、矩阵等。这些都可以通过特定的 LaTeX 命令来创建。
例如,以下是一些基本的 LaTeX 公式示例:
- 分数:
\frac{a}{b}
渲染为 。- 根号:
\sqrt{x}
渲染为 。- 求和符号:
\sum_{i=1}^{n} a_i
渲染为 。- 积分符号:
\int_{a}^{b} f(x) \, dx
渲染为。- 矩阵:使用
\begin{pmatrix}... \end{pmatrix}
环境,例如
LaTeX 公式在学术出版、科学文档和数学教育中被广泛使用,因为它们能够提供高质量的数学表达式渲染。
最后,你可以到 LaTex 公式编辑器网站(在线LaTeX公式编辑器-编辑器)上了解更多相关知识点,以及验证公式。
四、环境准备
1、环境
案例环境:1) Windows 10;2)Python 3.11
构建虚拟环境,安装相关包,主要是:torch、torchvision、transforms
如果使用 cuda 进行训练,查看自己的 cuda 版本对应安装 torch 相关
案例中 cuda 版本为 12.3,所以对应安装 torch 如下命令:
pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html
2、 pip 安装的一些主要 package
python -m pip install --upgrade pip
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
pip install modelscope==1.18.0
pip install transformers==4.46.2
pip install sentencepiece==0.2.0
pip install accelerate==1.1.1
pip install datasets==2.18.0
pip install peft==0.13.2
pip install swanlab==0.3.27
pip install qwen-vl-utils==0.0.8
pip install pandas==2.2.2
pip install oss2==2.19.1
pip install addict==2.4.0
pip install pillow==11.0.0
下面是对应解释:
python -m pip install --upgrade pip
- 这个命令用于升级pip工具本身到最新版本。保持pip更新可以确保您使用的是最新特性和安全修复。
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
- 这个命令将pip的默认源更换为清华大学提供的PyPI镜像源。这样做可以加速库的安装,因为清华大学的镜像源在国内访问速度较快,可以减少下载时间。
pip install modelscope==1.18.0
- 安装ModelScope库的特定版本1.18.0。ModelScope是一个提供多种预训练模型的平台,用于机器学习和深度学习任务。
pip install transformers==4.46.2
- 安装transformers库的特定版本4.46.2。这个库由Hugging Face提供,包含了大量的预训练模型,如BERT、GPT等,用于自然语言处理任务。
pip install sentencepiece==0.2.0
- 安装SentencePiece库的特定版本0.2.0。SentencePiece是一个用于文本分词的库,支持多种语言,常用于机器学习和自然语言处理任务。
pip install accelerate==1.1.1
- 安装Accelerate库的特定版本1.1.1。Accelerate是一个由Hugging Face提供的库,用于简化深度学习模型的分布式训练。
pip install datasets==2.18.0
- 安装datasets库的特定版本2.18.0。datasets库提供了一个简单的接口来加载、处理和分享数据集,常用于机器学习项目。
pip install peft==0.13.2
- 安装peft库的特定版本0.13.2。peft是一个轻量级的Python配置文件处理库,用于处理配置文件。
pip install swanlab==0.3.27
- 安装SwanLab库的特定版本0.3.27。SwanLab是一个用于机器学习和深度学习的库,提供了一些工具和功能来简化开发流程。
pip install qwen-vl-utils==0.0.8
- 安装qwen-vl-utils库的特定版本0.0.8。这个库可能与Qwen-VL模型相关,提供了一些工具和实用程序来支持Qwen-VL模型的使用。
pip install pandas==2.2.2
- 安装pandas库的特定版本2.2.2。pandas是一个强大的数据分析和操作库,提供了DataFrame等数据结构,广泛用于数据处理和分析。
pip install oss2==2.19.1
- 安装oss2库的特定版本2.19.1。oss2是阿里云对象存储服务(OSS)的Python SDK,用于在Python程序中操作阿里云OSS服务。
pip install addict==2.4.0
- 安装addict库的特定版本2.4.0。addict是一个轻量级的Python字典对象,可以像访问属性一样访问字典的键值。
pip install pillow==11.0.0
- 安装Pillow库的特定版本11.0.0。Pillow是Python Imaging Library(PIL)的一个分支,用于图像文件的打开、操作和保存。
注意:记得根据需要安装 torch,以及相关
3、创建一个虚拟环境
命令:virtualenv xxxxxProject
可以先切换到自己需要创建文件夹路径,创建管理自己的虚拟环境
(也可以使用Anacoda 管理虚拟环境)
五、图片数据准备,模型下载、训练、和测试
1、用于训练的图片数据下载与归档
1.1、这里使用的是 LaTeX_OCR 的数据集
LaTex_OCR 图片集下载地址:魔搭社区
本仓库有 5 个数据集
small
是小数据集,样本数 110 条,用于测试full
是印刷体约 100k 的完整数据集。实际上样本数略小于 100k,因为用 LaTeX 的抽象语法树剔除了很多不能渲染的 LaTeX。synthetic_handwrite
是手写体 100k 的完整数据集,基于full
的公式,使用手写字体合成而来,可以视为人类在纸上的手写体。样本数实际上略小于 100k,理由同上。human_handwrite
是手写体较小数据集,更符合人类在电子屏上的手写体。主要来源于CROHME
。我们用 LaTeX 的抽象语法树校验过了。human_handwrite_print
是来自human_handwrite
的印刷体数据集,公式部分和human_handwrite
相同,图片部分由公式用 LaTeX 渲染而来。
1.2、数据下载方式
加载训练集
- name 可选 small, full, synthetic_handwrite, human_handwrite, human_handwrite_print
- split 可选 train, validation, test
>>> from modelscope import MsDataset
>>> dataset = MsDataset.load('AI-ModelScope/LaTeX_OCR', subset_name='small')
>>> dataset
DatasetDict({
train: Dataset({
features: ['image', 'text'],
num_rows: 50
})
validation: Dataset({
features: ['image', 'text'],
num_rows: 30
})
test: Dataset({
features: ['image', 'text'],
num_rows: 30
})
})
1.3、使用 Pycharm 创建一个工程,并且使用上之前创建的,安装过对应 package 的虚拟环境
1.4、编写脚本,下载与归档 LaTex_OCR 图片集
实现脚本 ImageDataDownloadThenDataToCSV.py (主要是下载图片,把图片一些相关数据转为 csv)
# 导入所需的库
from modelscope.msdatasets import MsDataset # 从modelscope库中导入MsDataset类,用于加载数据集
import os # 导入os库,用于操作文件和目录
import pandas as pd # 导入pandas库,用于数据处理和CSV文件操作
MAX_DATA_NUMBER = 1000 # 定义处理数据的最大数量
# 定义数据集的名称、子集名称、分割方式和缓存目录
DOWNLOAD_DATA_SET_NAME = 'AI-ModelScope/LaTeX_OCR' # LaTeX_OCR数据集的名称
DOWNLOAD_SUBSET_NAME = 'default' # 数据集的子集名称
DOWNLOAD_SPLIT = 'train' # 数据集的分割方式,这里选择训练集
CACHE_DIR = "../../data" # 定义缓存目录,用于存储下载的数据集
# 定义保存图片和CSV文件的目录路径
SAVED_DIR = '../../data/LaTeX_OCR/images' # 定义保存图片的目录路径
SAVED_CSV_PATH = "../../data/LaTeX_OCR/LaTeX_OCR_dataset.csv" # 定义保存CSV文件的路径
# 检查目录是否已存在
if not os.path.exists(SAVED_DIR):
# 如果目录不存在,则从modelscope下载LaTeX_OCR图像描述数据集
ds = MsDataset.load(DOWNLOAD_DATA_SET_NAME, subset_name=DOWNLOAD_SUBSET_NAME, split=DOWNLOAD_SPLIT,
cache_dir=CACHE_DIR)
print(len(ds)) # 打印数据集的大小
# 设置处理的图片数量上限
total = min(MAX_DATA_NUMBER, len(ds))
# 创建保存图片的目录
os.makedirs(SAVED_DIR, exist_ok=True) # 如果目录不存在,则创建它
# 初始化存储图片路径和文本的列表
image_paths = [] # 用于存储图片路径的列表
texts = [] # 用于存储文本的列表
# 遍历数据集中的样本
for i in range(total):
# 获取每个样本的信息
item = ds[i]
text = item['text'] # 获取文本内容
image = item['image'] # 获取图片对象
# 保存图片并记录路径
image_path = os.path.abspath(f'{SAVED_DIR}/{i}.jpg') # 构建图片的完整路径
image.save(image_path) # 保存图片到指定路径
# 将路径和文本添加到列表中
image_paths.append(image_path) # 添加图片路径到列表
texts.append(text) # 添加文本到列表
# 每处理50张图片打印一次进度
if (i + 1) % 50 == 0:
print(f'Processing {i + 1}/{total} images ({(i + 1) / total * 100:.1f}%)')
# 将图片路径和文本保存为CSV文件
df = pd.DataFrame({
'image_path': image_paths, # 创建DataFrame,包含图片路径列
'text': texts # 创建DataFrame,包含文本列
})
# 将数据保存为CSV文件
df.to_csv(SAVED_CSV_PATH, index=False) # 保存DataFrame到CSV文件,不包含索引
print(f'数据处理完成,共处理了{total}张图片') # 打印处理完成的信息
else:
print(f'{SAVED_DIR}目录已存在,跳过数据处理步骤') # 如果目录已存在,则跳过数据处理步骤
下载的图片集:
生成的 csv 数据内容如下图:
1.5 把归档的图片数据进行 json 数据整理
将这些数据整理成Qwen2-VL需要的json格式,下面是目标的格式:
[
{
"id": "identity_1",
"conversations": [
{
"role": "user",
"value": "图片路径"
},
{
"role": "assistant",
"value": "LaTex公式"
}
]
},
...
]
下面进行简单解释
- id:数据对的编号
- conversations:人类与LLM的对话,类型是列表
- role:角色,user代表人类,assistant代表模型
- content:聊天发送的内容,其中user的value是图片路径,assistant的回复是LaTex公式
实现脚本 ImageDataCSVToJson.py
import pandas as pd # 导入pandas库,用于数据处理和CSV文件操作
import json # 导入json库,用于处理JSON数据
SAVED_CSV_PATH = "../../data/LaTeX_OCR/LaTeX_OCR_dataset.csv" # 定义CSV文件的路径,该文件包含LaTeX_OCR数据集的图片路径和文本
SAVED_TRAIN_JSON_PATH = "../../data/LaTeX_OCR/LaTeX_OCR_dataset_train_vl.json" # 定义训练集JSON文件的保存路径
SAVED_VALIDATION_JSON_PATH = "../../data/LaTeX_OCR/LaTeX_OCR_dataset_validation_vl.json" # 定义验证集JSON文件的保存路径
# 载入CSV文件
df = pd.read_csv(SAVED_CSV_PATH) # 使用pandas的read_csv函数读取CSV文件,并存储在DataFrame中
conversations = [] # 初始化一个空列表,用于存储对话数据
# 添加图片对话数据
for i in range(len(df)):
conversations.append({
"id": f"identity_{i + 1}", # 为每个对话分配一个唯一的ID
"conversations": [ # 定义对话内容的列表
{
"role": "user", # 指定消息来源为用户
"value": f"{df.iloc[i]['image_path']}" # 用户发送的消息,包含图片路径
},
{
"role": "assistant", # 指定消息来源为助手
"value": str(df.iloc[i]['text']) # 助手的回复,即图片的文本描述
}
]
})
# 将对话数据拆分为训练集和验证集
train_conversations = conversations[:-4] # 获取除了最后4条之外的所有对话作为训练集
val_conversations = conversations[-4:] # 获取最后4条对话作为验证集
# 保存训练集到JSON文件
with open(SAVED_TRAIN_JSON_PATH, 'w', encoding='utf-8') as f:
json.dump(train_conversations, f, ensure_ascii=False, indent=2) # 使用json.dump函数将训练集对话数据转换为JSON格式并写入文件
# 保存验证集到JSON文件
with open(SAVED_VALIDATION_JSON_PATH, 'w', encoding='utf-8') as f:
json.dump(val_conversations, f, ensure_ascii=False, indent=2) # 使用json.dump函数将验证集对话数据转换为JSON格式并写入文件
生成的 json 文件内容如下:
最后训练图片下载和json 数据工程结构如下
总之,数据集下载与处理方式
主要做了四件事情:
- 通过 Modelscope下载 LaTex_OCR 数据集
- 加载数据集,将图像保存到本地
- 将图像路径和描述文本转换为一个csv文件
- 将csv文件转换为json文件
2、Qwen2-VL 模型下载、训练、和测试
2.1 Qwen2-VL 模型
Qwen/Qwen2-VL-2B-Instruct网址:魔搭社区汇聚各领域最先进的机器学习模型,提供模型探索体验、推理、训练、部署和应用的一站式服务。https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct
这里还用到 SwanLab 进行数据可视化展示,所以需要注册安装 SwanLab
SwanLab是一个类似Tensorboard的开源训练图表可视化库,有着更轻量的体积与更友好的API。除了能记录指标,还能自动记录训练的logging、硬件环境、Python环境、训练时间等信息。
Swanlab 官网:https://swanlab.cn/
Swanlab Github :https://github.com/SwanHubX/SwanLab
注意:记得 pip install swanlab 安装工具
(如果不想使用的话,把train训练脚本中对应 Swanlab 的相关代码删除即可)
2.2 Qwen2-VL 模型下载、以及集合图片数据进行训练
实现脚本 Qwen2VLDownloadToTrain.py
import os
import torch # 导入PyTorch库,用于深度学习模型的构建和训练
from datasets import Dataset # 从datasets库中导入Dataset类,用于处理数据集
from modelscope import snapshot_download, AutoTokenizer # 从modelscope库中导入snapshot_download和AutoTokenizer,用于模型和分词器的加载
from swanlab.integration.transformers import SwanLabCallback # 从swanlab库中导入SwanLabCallback,用于SwanLab的回调功能
from qwen_vl_utils import process_vision_info # 从qwen_vl_utils库中导入process_vision_info,用于处理视觉信息
from peft import LoraConfig, TaskType, get_peft_model, PeftModel # 从peft库中导入LoraConfig、TaskType、get_peft_model和PeftModel,用于模型的配置和训练
from transformers import (
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq,
Qwen2VLForConditionalGeneration,
AutoProcessor,
) # 从transformers库中导入TrainingArguments、Trainer、DataCollatorForSeq2Seq、Qwen2VLForConditionalGeneration和AutoProcessor,用于模型训练和数据处理
import swanlab # 导入swanlab库,用于SwanLab的集成
import json # 导入json库,用于处理JSON数据
# 定义模型和路径变量
MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct" # 定义模型名称
MODEL_REVISION_NAME = "master" # 定义模型版本名称
CACHE_DIR = "../../data" # 定义缓存目录
MODEL_SAVED_PATH = CACHE_DIR + "/" + MODEL_NAME # 定义模型保存路径
IMAGE_SAVED_TRAIN_JSON_PATH = "../../data/LaTeX_OCR/LaTeX_OCR_dataset_train_vl.json" # 定义训练集图像JSON文件路径
IMAGE_SAVED_VALIDATION_JSON_PATH = "../../data/LaTeX_OCR/LaTeX_OCR_dataset_validation_vl.json" # 定义验证集图像JSON文件路径
MODEL_DATA_VL_TRAIN_JSON_PATH = CACHE_DIR + "/Qwen/data_vl_train.json" # 定义训练数据JSON文件路径
MODEL_DATA_VL_VALIDATION_JSON_PATH = CACHE_DIR + "/Qwen/data_vl_test.json" # 定义验证数据JSON文件路径
MODEL_TRAIN_RESULT_OUTPUT_PATH = "../../output/Qwen2-VL-2B/" # 定义模型训练结果输出路径
MAX_LENGTH = 8192 # 定义最大序列长度
PROMPT = "你是一个LaText OCR助手,目标是读取用户输入的照片,转换成LaTex公式。" # 定义提示信息
def process_func(example):
"""
将数据集进行预处理,包括文本和视觉信息的处理。
"""
input_ids, attention_mask, labels = [], [], [] # 初始化输入ID、注意力掩码和标签列表
conversation = example["conversations"] # 获取对话内容
image_file_path = conversation[0]["value"] # 从对话中提取图像文件路径
output_content = conversation[1]["value"] # 从对话中提取输出内容
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": f"{image_file_path}",
"resized_height": 500,
"resized_width": 100,
},
{"type": "text", "text": PROMPT},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
) # 使用processor处理文本,应用聊天模板
image_inputs, video_inputs = process_vision_info(messages) # 处理视觉信息
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = {key: value.tolist() for key, value in inputs.items()} # 将tensor转换为list,方便后续拼接
instruction = inputs
response = tokenizer(f"{output_content}", add_special_tokens=False) # 使用tokenizer处理输出内容
input_ids = (
instruction["input_ids"][0] + response["input_ids"] + [tokenizer.pad_token_id]
) # 拼接输入ID
attention_mask = instruction["attention_mask"][0] + response["attention_mask"] + [1] # 拼接注意力掩码
labels = (
[-100] * len(instruction["input_ids"][0])
+ response["input_ids"]
+ [tokenizer.pad_token_id]
) # 拼接标签
if len(input_ids) > MAX_LENGTH: # 如果超过最大长度,则进行截断
input_ids = input_ids[:MAX_LENGTH]
attention_mask = attention_mask[:MAX_LENGTH]
labels = labels[:MAX_LENGTH]
input_ids = torch.tensor(input_ids) # 将输入ID转换为tensor
attention_mask = torch.tensor(attention_mask) # 将注意力掩码转换为tensor
labels = torch.tensor(labels) # 将标签转换为tensor
inputs['pixel_values'] = torch.tensor(inputs['pixel_values']) # 将像素值转换为tensor
inputs['image_grid_thw'] = torch.tensor(inputs['image_grid_thw']).squeeze(0) # 将图像网格尺寸转换为tensor并压缩维度
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels,
"pixel_values": inputs['pixel_values'], "image_grid_thw": inputs['image_grid_thw']} # 返回处理后的数据
def predict(messages, model):
"""
使用模型进行预测,生成文本。
"""
# 准备推理
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
) # 使用processor处理文本,应用聊天模板
image_inputs, video_inputs = process_vision_info(messages) # 处理视觉信息
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to("cuda") # 将输入数据移动到CUDA设备
# 生成输出
generated_ids = model.generate(**inputs, max_new_tokens=MAX_LENGTH) # 使用模型生成输出ID
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
] # 修剪输出ID
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
) # 使用processor解码输出ID
return output_text[0] # 返回生成的文本
# 在modelscope上下载Qwen2-VL模型到本地目录下
model_dir = snapshot_download(MODEL_NAME, cache_dir=CACHE_DIR, revision=MODEL_REVISION_NAME)
# 使用Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained(MODEL_SAVED_PATH, use_fast=False, trust_remote_code=True) # 加载分词器
processor = AutoProcessor.from_pretrained(MODEL_SAVED_PATH) # 加载处理器
origin_model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_SAVED_PATH, device_map="auto",
torch_dtype=torch.bfloat16, trust_remote_code=True, )
origin_model.enable_input_require_grads() # 开启梯度检查点
# 读取训练集和验证集JSON文件,并保存到指定路径
with open(IMAGE_SAVED_TRAIN_JSON_PATH, 'r') as f:
data = json.load(f) # 读取训练集JSON文件
train_data = data # 获取训练数据
with open(MODEL_DATA_VL_TRAIN_JSON_PATH, "w") as f:
json.dump(train_data, f) # 保存训练数据到JSON文件
with open(IMAGE_SAVED_VALIDATION_JSON_PATH, 'r') as f:
data = json.load(f) # 读取验证集JSON文件
validation_data = data # 获取验证数据
with open(MODEL_DATA_VL_VALIDATION_JSON_PATH, "w") as f:
json.dump(validation_data, f) # 保存验证数据到JSON文件
# 处理数据集:读取json文件
train_ds = Dataset.from_json(MODEL_DATA_VL_TRAIN_JSON_PATH) # 从JSON文件加载训练数据集
train_dataset = train_ds.map(process_func) # 对训练数据集应用预处理函数
# 配置LoRA
config = LoraConfig(
task_type=TaskType.CAUSAL_LM, # 设置任务类型为因果语言模型
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], # 设置LoRA目标模块
inference_mode=False,
r=64, # 设置LoRA秩
lora_alpha=16, # 设置LoRA alpha,具体作用参见LoRA原理
lora_dropout=0.05, # 设置Dropout比例
bias="none", # 设置偏置类型
)
# 获取LoRA模型
train_peft_model = get_peft_model(origin_model, config) # 使用配置获取LoRA模型
# 配置训练参数
args = TrainingArguments(
output_dir=MODEL_TRAIN_RESULT_OUTPUT_PATH, # 设置输出目录
per_device_train_batch_size=4, # 设置每个设备的训练批次大小
gradient_accumulation_steps=4, # 设置梯度累积步数
logging_steps=10, # 设置日志记录步数
logging_first_step=10, # 设置首次日志记录步数
num_train_epochs=2, # 设置训练轮数
save_steps=100, # 设置保存模型的步数
learning_rate=1e-4, # 设置学习率
save_on_each_node=True, # 设置是否在每个节点上保存
gradient_checkpointing=True, # 设置是否使用梯度检查点
report_to="none", # 设置不向任何平台报告训练进度
)
# 设置SwanLab回调
swanlab_callback = SwanLabCallback(
project="Qwen2-VL-ft-latexocr", # 设置项目名称
experiment_name="7B-1kdata", # 设置实验名称
config={
"model": "https://modelscope.cn/models/Qwen/Qwen2-VL-7B-Instruct", # 设置模型链接
"dataset": "https://modelscope.cn/datasets/AI-ModelScope/LaTeX_OCR/summary", # 设置数据集链接
"github": "https://github.com/your_repository", # 设置GitHub链接
"model_id": MODEL_NAME, # 设置模型ID
"train_dataset_json_path": MODEL_DATA_VL_TRAIN_JSON_PATH, # 设置训练数据集JSON文件路径
"val_dataset_json_path": MODEL_DATA_VL_VALIDATION_JSON_PATH, # 设置验证数据集JSON文件路径
"output_dir": MODEL_TRAIN_RESULT_OUTPUT_PATH, # 设置输出目录
"prompt": PROMPT, # 设置提示信息
"train_data_number": len(train_ds), # 设置训练数据数量
"token_max_length": MAX_LENGTH, # 设置最大序列长度
"lora_rank": 64, # 设置LoRA秩
"lora_alpha": 16, # 设置LoRA alpha
"lora_dropout": 0.1, # 设置LoRA Dropout比例
},
)
# 配置Trainer
trainer = Trainer(
model=train_peft_model, # 设置模型
args=args, # 设置训练参数
train_dataset=train_dataset, # 设置训练数据集
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True), # 设置数据收集器
callbacks=[swanlab_callback], # 设置回调函数
)
# 开启模型训练
trainer.train() # 训练模型
# ====================测试===================
# 配置测试参数
val_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, # 设置任务类型为因果语言模型
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], # 设置LoRA目标模块
inference_mode=True, # 设置为推理模式
r=64, # 设置LoRA秩
lora_alpha=16, # 设置LoRA alpha
lora_dropout=0.05, # 设置Dropout比例
bias="none", # 设置偏置类型
)
# 获取测试模型,从output_dir中获取最新的checkpoint
load_model_path = f"{MODEL_TRAIN_RESULT_OUTPUT_PATH}/checkpoint-{max([int(d.split('-')[-1]) for d in os.listdir(MODEL_TRAIN_RESULT_OUTPUT_PATH) if d.startswith('checkpoint-')])}"
print(f"load_model_path: {load_model_path}")
val_peft_model = PeftModel.from_pretrained(origin_model, model_id=load_model_path, config=val_config) # 加载测试模型
# 读取测试数据
with open(MODEL_DATA_VL_VALIDATION_JSON_PATH, "r") as f:
test_dataset = json.load(f) # 读取测试数据JSON文件
test_image_list = [] # 初始化测试图像列表
for item in test_dataset:
image_file_path = item["conversations"][0]["value"] # 获取图像文件路径
label = item["conversations"][1]["value"] # 获取标签
messages = [{
"role": "user",
"content": [
{
"type": "image",
"image": image_file_path,
"resized_height": 100,
"resized_width": 500,
},
{
"type": "text",
"text": PROMPT,
}
]}]
response = predict(messages, val_peft_model) # 使用预测函数生成响应
print(f"predict:{response}")
print(f"gt:{label}\n")
test_image_list.append(swanlab.Image(image_file_path, caption=response)) # 将图像和响应添加到测试图像列表
swanlab.log({"Prediction": test_image_list}) # 使用SwanLab记录预测结果
# 停止SwanLab记录,需要调用swanlab.finish()
swanlab.finish() # 结束SwanLab记录
训练过程如下图:
训练测试集结果:
SwanLab 进行训练结果展示如下图:
2.3 加载上一步训练 lora 微调后的保存的结果模型,然后进行图片识别测试
实现脚本 TestQwen2VLPracticeResult.py
from transformers import Qwen2VLForConditionalGeneration, \
AutoProcessor # 从transformers库中导入Qwen2VLForConditionalGeneration模型和AutoProcessor处理器类
from qwen_vl_utils import process_vision_info # 从qwen_vl_utils库中导入process_vision_info函数,用于处理视觉信息
from peft import PeftModel, LoraConfig, TaskType # 从peft库中导入PeftModel、LoraConfig和TaskType,用于模型的LoRA配置和训练
import time # 导入time模块,用于时间操作
# 定义模型和路径变量
MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct" # 定义模型名称
CACHE_DIR = "../../data" # 定义缓存目录
MODEL_SAVED_PATH = CACHE_DIR + "/" + MODEL_NAME # 定义模型保存路径
MODEL_TRAIN_RESULT_OUTPUT_PATH = "../../output/Qwen2-VL-2B/" # 定义模型训练结果输出路径
LORA_MODEL_NAME = "checkpoint-124" # 定义LoRA模型的检查点名称
TEST_IMAGE_PATH = "../../data/LaTeX_OCR/images/2.jpg" # 定义测试图像路径
PROMPT = "你是一个LaText OCR助手,目标是读取用户输入的照片,转换成LaTex公式。" # 定义提示信息
config = LoraConfig( # 创建LoRA配置
task_type=TaskType.CAUSAL_LM, # 设置任务类型为因果语言模型
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], # 设置LoRA目标模块
inference_mode=True, # 设置为推理模式
r=64, # 设置LoRA秩
lora_alpha=16, # 设置LoRA alpha,具体作用参见LoRA原理
lora_dropout=0.05, # 设置Dropout比例
bias="none", # 设置偏置类型
)
# 加载模型
model = Qwen2VLForConditionalGeneration.from_pretrained( # 从预训练路径加载Qwen2VLForConditionalGeneration模型
MODEL_SAVED_PATH, torch_dtype="auto", device_map="auto"
)
model = PeftModel.from_pretrained( # 加载PeftModel
model, model_id=f"{MODEL_TRAIN_RESULT_OUTPUT_PATH}/{LORA_MODEL_NAME}", config=config
)
processor = AutoProcessor.from_pretrained(MODEL_SAVED_PATH) # 从预训练路径加载AutoProcessor处理器
# 记录开始时间
start_time = time.time() # 使用time模块记录开始时间
# 定义用户消息,包含图像和文本提示
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": TEST_IMAGE_PATH, # 指定测试图像路径
"resized_height": 100, # 指定图像调整后的高度
"resized_width": 500, # 指定图像调整后的宽度
},
{"type": "text", "text": f"{PROMPT}"}, # 指定提示信息
],
}
]
# 准备推理
text = processor.apply_chat_template( # 使用处理器应用聊天模板
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages) # 处理视觉信息
inputs = processor( # 使用处理器处理输入
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to("cuda") # 将输入数据移动到CUDA设备
# 推理:生成输出
generated_ids = model.generate(**inputs, max_new_tokens=8192) # 使用模型生成输出ID
generated_ids_trimmed = [ # 修剪输出ID
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode( # 使用处理器批量解码输出ID
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
# 记录结束时间并计算执行时间
end_time = time.time() # 使用time模块记录结束时间
execution_time = end_time - start_time # 计算执行时间
print(f"Execution time: {execution_time}, output:{output_text[0]}") # 打印执行时间和输出结果
测试的图片(也可以测试其他的图片):
执行脚本的测试结果:
LaTex 公式编辑器网站上验证的结果:
\rho _ { L } ( q ) = \sum _ { m = 1 } ^ { L } P _ { L } ( m ) \frac { 1 } { q ^ { m - 1 } } .
Qwen2VL下载训练测试的工程结构如下图:
六、工程下载
工程下载地址:https://download.csdn.net/download/u014361280/90154588
工程中的一些图片和模型等 data 资源被删掉了
大家可以按下面执行顺序,data 资源,和 output 输出 会自动下载和生成
1、scripts\ImageDataHandler\ImageDataDownloadThenDataToCSV.py
2、scripts\ImageDataHandler\ImageDataCSVToJson.py
3、scripts\QWenVL\Qwen2VLDownloadToTrain.py
4、scripts\QWenVL\TestQwen2VLPracticeResult.py
附录
1、如果运行代码中,报如下类似的错,对应的 pip install 安装对应的模块包即可