当前位置: 首页 > article >正文

复现OpenVLA:开源的视觉-语言-动作模型及原理详解

复现OpenVLA:开源的视觉-语言-动作模型及原理详解

  • 1. 摘要
  • 2. 引言
  • 3. 相关工作
  • 4. 模型结构
    • 4.1 模视觉-语言模型VLM
    • 4.2 训练流程
    • 4.3 图像分辨率
    • 4.4 微调视觉编码器
    • 4.5 训练轮数
    • 4.6 学习率
    • 4.7 训练细节
    • 4.8 参数高效微调
  • 5. 复现
    • 5.1 下拉代码
    • 5.2 安装环境依赖
      • 5.2.1 创建conda环境
      • 5.2.2 安装torch
      • 5.2.3 安装openvla repo
      • 5.2.4 安装其它依赖
    • 5.3 加载预训练模型推理

开篇总结:OpenVLA是第一个 开源支持高效微调参数量比RT-2-X更 的通用VLA。

1. 摘要

在这里插入图片描述

OpenVLA,一个7 B参数的开源视觉语言动作模型(VLA),在Open X-Embodiment数据集的970 k机器人集上进行了训练。OpenVLA为通用机器人操作策略设定了一个新的艺术状态。它支持开箱即用控制多个机器人,并可以通过参数高效的微调快速适应新的机器人领域。OpenVLA权重和PyTorch训练管道是完全开源的,可以从HuggingFace下载和微调模型。
在互联网规模的视觉语言数据和多样化的机器人演示的组合上预先训练的大型策略有可能改变我们教机器人新技能的方式:与其从头开始训练新的行为,我们可以微调这样的视觉-语言-动作(VLA)模型,以获得用于视觉控制的鲁棒的、可推广的策略。
OpenVLA构建在Llama 2语言模型之上,并结合了视觉编码器,融合了DINOv2SigLIP的预训练特征。
作为增加的数据多样性和新模型组件的产物,OpenVLA在通才操作方面表现出了强大的结果,在29个任务和多个机器人实施例中,绝对任务成功率超过RT-2-X(55 B)等封闭模型16.5%参数减少了7倍
实验进一步表明,我们可以有效地微调OpenVLA以适应新的设置,在涉及多个对象和强语言基础能力的多任务环境中具有特别强的泛化结果,并且比从头开始的模仿学习方法(如扩散策略)高出20.4%
我们还探索计算效率,OpenVLA可以通过主流的低秩自适应方法微调消费者GPU,并通过量化有效地服务,而不会影响下游的成功率。

2. 引言

用于机器人操作的学习策略的一个关键弱点是它们无法具有泛化训练数据以外的能力:它们对场景干扰或新物体缺乏鲁棒性,并且难以执行看不见的任务指令。但是,除了机器人技术之外,现有的视觉和语言基础模型,如CLIP、SigLIP和Llama 2,甚至能够进行这些类型的泛化,这些泛化都源于互联网规模的预训练数据集所捕获的先验知识。然而复制这种规模的机器人预训练仍然是一个公开的挑战—即使是最大的机器人操作数据集也只有10万到100万个示例—这种不平衡意味着一个机会:使用现有的视觉和语言基础模型作为训练机器人策略的核心构建块,这些策略可以推广到新对象,场景和任务。
为了实现这一目标,现有的工作已经探索了将预训练的语言和视觉语言模型进行整合用于机器人表征学习。
然而,有两个关键原因阻止了现有VLA的广泛使用:
1)当前模型是封闭的,无法开源访问,对模型架构,训练过程和数据混合的可见性有限;
2)现有工作没有提供部署和适应新机器人,环境和任务的VLA的最佳实践-特别是在商品硬件上(例如,消费级GPU)。我们认为,为了为未来的研究和开发奠定坚实的基础,机器人技术需要支持有效微调和适应的开源通用VLA,类似于围绕开源语言模型的现有生态系统。
为此,我们引入了OpenVLA,这是一个7B参数的开源VLA,它为通用机器人操作策略建立了一个新的SOTA水平。OpenVLA由一个预训练的视觉-条件语言模型主干组成,该模型以多个粒度捕获视觉特征,并在Open-X的970k机器人操作轨迹的大型多样化数据集上进行了微调。
第一个证明了利用低秩自适应[LoRA]和模型量化的计算高效微调方法的有效性,以促进在消费级GPU而不是大型服务器节点上适应OpenVLA模型,而不会影响性能。

3. 相关工作

视觉条件语言模型(VLM)在互联网规模的数据上进行训练,以从输入图像和语言提示中生成自然语言,已被用于从视觉问答[28-31]到对象定位[32,33]的无数应用。推动最近VLM的关键进步之一是模型架构,它将来自预训练的视觉编码器的特征预训练语言模型连接起来,直接建立在计算机视觉和自然语言建模的进步基础上,以创建强大的多模态模型
OpenVLA采用了一种更端到端的方法,直接微调VLM,通过将它们视为语言模型词汇表中的标记来生成机器人动作。
视觉语言动作模型(VLA),将机器人控制动作直接融合到VLM主干中。这有三个主要好处:(1)它在大型互联网规模的视觉语言数据集上执行预训练的视觉和语言组件的对齐,(2)使用通用架构,而不是为机器人控制定制,允许我们利用现有VLM训练的可扩展基础设施,并扩展到以最少的代码修改训练十亿参数策略(3)它为机器人技术从VLMs的快速改进中受益提供了直接途径。

4. 模型结构

4.1 模视觉-语言模型VLM

最新的VLM的架构由三个主要部分组成(见下图):
(1)将图像输入映射到许多“image patch embeddings”的视觉编码器;
(2)将视觉编码器的输出嵌入并将其映射到语言模型的输入空间的projector;
(3)大型语言模型(LLM)骨干。
在VLM训练过程中,该模型是端到端训练的。

在这里插入图片描述
Prismatic遵循与上述相同的标准架构,具有600 M参数的视觉编码器,小型2层MLP的projector和7B参数的Llama 2语言模型骨干。值得注意的是,Prismatic使用两部分视觉编码器,由预训练的SigLIP和DinoV 2模型组成。输入图像patch分别通过两个编码器,得到的特征向量按通道连接。
ps:与更常用的视觉编码器(如CLIP或仅SigLIP编码器)相比,DinoV 2功能的添加已被证明有助于改进空间推理,这对机器人控制特别有帮助。

4.2 训练流程

为了训练OpenVLA,对主干进行微调,以进行机器人动作预测。将动作预测问题表述为“视觉语言”任务,其中输入观察图像和自然语言任务指令映射到预测的机器人动作串

4.3 图像分辨率

输入图像的分辨率对VLA训练的计算要求有显著影响,因为更高分辨率的图像导致更多的image patch tokens,从而导致更长的上下文长度,从而二次增加训练计算。
比较了输入为224 × 224px和384 × 384px的VLA,但在评估中没有发现性能差异,而后者的训练时间长3倍。因此,我们为最终的OpenVLA模型选择了224 × 224px的分辨率。
ps:请注意,在许多VLM基准测试中,提高分辨率确实可以提高性能,但在这里还没有看到VLA的这种趋势。

4.4 微调视觉编码器

之前对VLM的研究发现,在VLM训练期间冻结视觉编码器通常会带来更高的性能。直观地,冻结视觉编码器可以更好地保留从其互联网规模的预训练中学习到的鲁棒特征。然而,我们发现在VLA训练期间微调视觉编码器对于良好的VLA性能至关重要。我们假设,预先训练的视觉骨干可能无法捕捉到足够的关于场景重要部分的细粒度空间细节,以实现精确的机器人控制。

4.5 训练轮数

典型的LLM或VLM训练在其训练数据集中最多运行一个或两个epoch。相比之下,我们发现VLA训练通过训练数据集进行更多次的训练是很重要的,真实的机器人性能不断提高,直到训练动作令牌准确率超过95%。我们的最终训练运行通过其训练数据集完成了27个epoch。

4.6 学习率

我们在VLA训练的多个数量级上扫描了学习率,并使用2e-5的固定学习率(与VLM预训练期间使用的学习率相同)获得了最佳结果。没有发现学习率热身可以带来好处。

4.7 训练细节

最终的OpenVLA模型在64个A100 GPU的集群上训练14天,或总共21,500个A100小时,batch size 为2048。
在推理过程中,OpenVLA在以float16精度加载时需要15GB的GPU内存(即,没有量化),在一个NVIDIA RTX 4090 GPU上以大约6Hz的频率运行(没有编译、推测解码或其他推理加速技巧)。
我们可以在推理过程中通过量化进一步减少OpenVLA的内存占用,而不会影响实际机器人任务的性能。

4.8 参数高效微调

具体来说,我们比较了以下微调方法:FFT完全微调在微调期间更新所有权重,last layer only仅微调OpenVLA的Transformer骨干的最后一层和token embedding矩阵,frozen vision结视觉编码器,但微调所有其他权重,sandwich fine-tuning不冻结视觉编码器,token embedding矩阵和最后一层,LoRA使用低秩自适应技术,具有多个秩值r,应用于模型的所有线性层。

在这里插入图片描述

5. 复现

5.1 下拉代码

git clone https://gitcode.com/gh_mirrors/op/openvla.git
# 下面的也行
git clone https://github.com/openvla/openvla.git

注意:
1.vla-scripts/是有关OpenVLA模型的完整培训和验证脚本。
2.scripts/主要是原始(基础)prismatic-vlms存储库的遗留物,支持训练和评估视觉条件语言模型;注意:虽然可以使用此仓库来训练VLMs和VLAs,尝试使用现有的OpenVLA模型生成语言(通过scripts/generate.py)是行不通的(因为只训练当前的OpenVLA模型来生成动作,并且只训练动作)。
在这里插入图片描述

5.2 安装环境依赖

5.2.1 创建conda环境

conda create -n openvla python=3.10 -y
conda activate openvla

5.2.2 安装torch

这个存储库是使用Python 3.10构建的,但应该与任何Python>=3.8向后兼容。需要PyTorch 2.2*

conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -y

5.2.3 安装openvla repo

cd openvla
pip install -e .

5.2.4 安装其它依赖

# Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention)
#   =>> If you run into difficulty, try `pip cache remove flash_attn` first
pip install packaging ninja
ninja --version; echo $?  # Verify Ninja --> should return exit code "0"
pip install "flash-attn==2.5.5" --no-build-isolation

5.3 加载预训练模型推理

# 安装最少依赖项 (`torch`, `transformers`, `timm`, `tokenizers`, 等)
# pip install -r https://raw.githubusercontent.com/openvla/openvla/main/requirements-min.txt
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
import torch

# 加载处理器与VLA
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    attn_implementation="flash_attention_2",  # 需要`flash_attn`
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to("cuda:0")

# 获取图像输入并格式化提示
image: Image.Image = get_from_camera(...)  # 假设get_from_camera是获取摄像头图像的函数
prompt = "In: What action should the robot take to {<INSTRUCTION>}?\nOut:"

# 预测动作(7自由度;对BridgeData V2解归一化)
inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

# 执行动作...
robot.act(action, ...)

http://www.kler.cn/news/307297.html

相关文章:

  • 【Go开发】Go语言结构体,与java类不一样的定义方式
  • 推荐|基于springBoot智能推荐的卫生健康系统设计与实现(源码+论文+数据库)
  • 【附源码】用Python开发一个音乐下载工具,并打包EXE文件,所有音乐都能搜索下载!
  • el-table 的单元格 + 图表 + 排序
  • 动手学深度学习(pytorch土堆)-03常见的Transforms
  • 图论篇--代码随想录算法训练营第五十六天打卡| 108. 冗余连接,109. 冗余连接II
  • 【SQL】百题计划:SQL排序Order by的使用。
  • Flutter Error: Type ‘UnmodifiableUint8ListView‘ not found
  • 刷题DAY36
  • 初中生物--5.单细胞生物
  • VuePress搭建文档网站/个人博客(详细配置)主题配置-导航栏配置
  • 【开源免费】基于SpringBoot+Vue.JS企业客户管理系统(JAVA毕业设计)
  • Linux命令:文本处理工具sed详解
  • django中F()和Q()的用法
  • 保姆级离线+windows环境+大模型前端UI安装(二)
  • 基于Spring Boot的停车场管理系统的设计与实现
  • 【STL】 set 与 multiset:基础、操作与应用
  • Vue路由配置、网络请求访问框架项目、element组件介绍学习
  • 数据库连接池与Druid【后端 16】
  • STM32 HAL freertos零基础(十)软件定时器
  • Renesas R7FA8D1BH (Cortex®-M85)控制ISLS29035
  • Unity-Transform类-父子关系
  • 五、(JS)window中的定时器
  • PhotoZoom Pro / Classic 9.0.2激活版安装激活图文教程
  • 栈与队列(c语言实现)
  • GAMES101(2~3作业)
  • 【系统架构设计师】单例模式(Singleton Pattern)
  • PCIe进阶之TL:Common Packet Header Fields TLPs with Data Payloads Rules
  • MYSQL数据库基础篇——MYSQL的安装与使用
  • Go中如何找到哪里依赖了某个module,如何找到所有module的最大GoVersion