当前位置: 首页 > article >正文

flash-attention保姆级安装教程

FlashAttention安装教程

FlashAttention 是一种高效且内存优化的注意力机制实现,旨在提升大规模深度学习模型的训练和推理效率。

  • 高效计算:通过优化 IO 操作,减少内存访问开销,提升计算效率。

  • 内存优化:降低内存占用,使得在大规模模型上运行更加可行。

  • 精确注意力:保持注意力机制的精确性,不引入近似误差。

  • FlashAttention-2 是 FlashAttention 的升级版本,优化了并行计算策略,充分利用硬件资源。改进了工作负载分配,进一步提升计算效率。

  • FlashAttention-3:FlashAttention-3 是专为 Hopper GPU(如 H100)优化的版本,目前处于 Beta 测试阶段。


常见问题:
安装成功后,实际模型代码运行时报错未安装,核心原因就是cxx11abiFALSE这个参数,表示该包在构建时不启用 C++11 ABI。
必须开启不使用才行。否则报错如下:
ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn.


最佳安装步骤(方法1)

  1. 安装依赖
    • 基础环境:cuda12.1、nvcc.
    • 安装python,示例3.10。
    • 安装PyTorch,示例orchtorch2.3.0; torchvision0.18.0
    • ninja Python 包
  2. 获取releases对应的whl包
    - 地址:https://github.com/Dao-AILab/flash-attention/releases
    - 按照系统环境选whl
    在这里插入图片描述
    3. 我的环境对应的包是:flash_attn-2.7.2.post1+cu12torch2.3cxx11abiTRUE-cp310-cp310-linux_x86_64.whl,解释如下:
    • flash_attn: 包的名称,表示这个 Wheel 文件是 flash_attn 包的安装文件。
    • 2.7.2.post1: 包的版本号,遵循 PEP 440 版本规范。
      • 2.7.2: 主版本号,表示这是 flash_attn 的第 2.7.2 版本。
      • post1: 表示这是一个“后发布版本”(post-release),通常用于修复发布后的某些问题。
    • +cu12torch2.3cxx11abiFALSE: 构建标签,表示该 Wheel 文件是在特定环境下构建的。
      • cu12: 表示该包是针对 CUDA 12 构建的。
      • torch2.3: 表示该包是针对 PyTorch 2.3 构建的。
      • cxx11abiFALSE: 表示该包在构建时不启用 C++11 ABI(Application Binary Interface)。如果安装包后不识别,就要选为False的版本。
    • cp310: Python 版本的标签,表示该包是为 Python 3.10 构建的。
      • cp310: 是 cpython 3.10 的缩写,表示该包适用于 CPython 解释器的 3.10 版本。
    • linux_x86_64: 平台标签,表示该包是为 Linux 操作系统和 x86_64 架构(即 64 位 Intel/AMD 处理器)构建的。
    • .whl: 文件扩展名,表示这是一个 Python Wheel 文件。Wheel 是 Python 的一种二进制分发格式,用于快速安装包。

如何安装

可以使用 pip 安装这个 Wheel 文件:

pip install flash_attn-2.7.2.post1+cu12torch2.3cxx11abiTRUE-cp310-cp310-linux_x86_64.whl --no-build-isolation


常规安装步骤(方法二)

  1. 安装依赖

    • CUDA 工具包或 ROCm 工具包
    • PyTorch 1.12 及以上版本
    • packagingninja Python 包
    pip install packaging ninja
    
  2. 安装 FlashAttention

    # 后面--no-build-isolation参数是为了pip 会直接在当前环境中构建包,使用当前环境中已安装的依赖项。
    # 如果当前环境缺少构建所需的依赖项,构建过程可能会失败。
    pip install flash-attn --no-build-isolation
    

    或从源码编译:

    # 下载源码后,进行编译
    cd flash-attention
    python setup.py install
    
  3. 运行测试

    export PYTHONPATH=$PWD
    pytest -q -s test_flash_attn.py
    
  4. 补充说明

    4.1 上面运行时,建议设置参数MAX_JOBS,限制最大进程数,不然系统容易崩。本人在docker下安装,直接干重启了,所以建议如下方式运行:

    MAX_JOBS=4 pip install flash-attn --no-build-isolation
    

    4.2 如果运行时会出现警告且推理速度依旧很慢,需要继续从源码安装rotary和layer_norm,cd到源码的那两个文件夹,执行 python setup.py install进行安装,如果命令报错弃用,可能要用easy_install命令。
    在这里插入图片描述

接口使用

import flash_attn_interface
flash_attn_interface.flash_attn_func()

硬件支持

NVIDIA CUDA 支持

  • 支持 GPU:Ampere、Ada 或 Hopper 架构 GPU(如 A100、RTX 3090、RTX 4090、H100)。
  • 数据类型:FP16 和 BF16。
  • 头维度:支持所有头维度,最大至 256。

AMD ROCm 支持

  • 支持 GPU:MI200 或 MI300 系列 GPU。
  • 数据类型:FP16 和 BF16。
  • 后端:支持 Composable Kernel (CK) 和 Triton 后端。

性能优化

Triton 后端

Triton 后端的 FlashAttention-2 实现仍在开发中,目前支持以下特性:

  • 前向和反向传播:支持因果掩码、变长序列、任意 Q 和 KV 序列长度、任意头大小。
  • 多查询和分组查询注意力:目前仅支持前向传播,反向传播支持正在开发中。

性能改进

  • 并行编译:使用 ninja 工具进行并行编译,显著减少编译时间。
  • 内存管理:通过设置 MAX_JOBS 环境变量,限制并行编译任务数量,避免内存耗尽。

结论

FlashAttention 系列通过优化计算和内存使用,显著提升了注意力机制的效率。无论是研究人员还是工程师,都可以通过本文提供的安装和使用指南,快速上手并应用于实际项目中。随着 FlashAttention-3 的推出,针对 Hopper GPU 的优化将进一步推动大规模深度学习模型的发展。

参考链接

  • FlashAttention 源码


http://www.kler.cn/a/460350.html

相关文章:

  • 【开源免费】基于SpringBoot+Vue.JS保密信息学科平台(JAVA毕业设计)
  • [算法] [leetcode-324] 摆动排序 II
  • gitlab的搭建及使用
  • _使用CLion的Vcpkg安装SDL2,添加至CMakelists时报错,编译报错
  • openEuler ARM使用vdbench50407
  • 人工智能机器学习基础篇】——深入详解强化学习 基础知识,理解马尔可夫决策过程(MDP)、策略、价值函数等关键概念
  • 脚本方式 迁移 老GITLAB项目到新GITLAB
  • 前端vue+el-input实现输入框中文字高亮标红效果(学习自掘金博主文章)
  • 服务器系统维护与安全配置
  • 黑马商城:MybatisPlus
  • img上的title属性和alt属性的区别是什么?
  • Oracle 数据库 dmp文件从高版本导入低版本的问题处理
  • C++ 环境搭建 - 安装编译器、IDE选择
  • WebRTC音视频通话系统需求(项目预算)
  • ffmpeg 编译+ libx264
  • Golang 的AI 框架库
  • Windows电脑带有日历的桌面备忘记事工具
  • shell脚本的使用
  • 【基础还得练】EM算法中的E
  • 【Qt】信号和槽机制
  • 【MyBatis-Plus】让 MyBatis 更简单高效
  • 【Kafka 消息队列深度解析与应用】
  • 基于zynq在linux下的HDMI实战
  • labelme2yolov8-seg 草稿()
  • 头歌python:多进程和多线程
  • 年会头投票小游戏