当前位置: 首页 > article >正文

FlashMLA(DeepSeek开源周,第一个框架):含源码分析

1. 概述

FlashMLA 是由 DeepSeek 原创开发的一种深度学习框架,专门用于加速多头注意力机制(MLA)架构的推理过程。它通过优化内存管理和计算效率,显著提升了模型在高性能 GPU 上的推理速度。FlashMLA 主要适用于 DeepSeek 的架构模型(如 DeepSeek-R1 和 DeepSeek-V3),并专为 NVIDIA H 系列显卡(如 H800 SXM5)进行了深度优化。

2.开源与代码分析

目录如下:

flash_api.cpp
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
​
#include <cutlass/fast_math.h>
​
#include "flash_mla.h"
#include "static_switch.h"
​
/**
 * 宏定义:CHECK_DEVICE、CHECK_SHAPE 和 CHECK_CONTIGUOUS 是用于验证输入张量的设备、形状和存储格式的便捷工具。
硬件检查:通过 at::cuda::getCurrentDeviceProperties 获取 GPU 属性,确保代码在兼容的硬件(如 SM90 架构)上运行。
张量操作:view 和 reshape 用于调整输入张量的维度,以适应不同的计算需求。
CUDA 内核调用:run_mha_fwd_splitkv_mla 是核心的 CUDA 内核函数,负责执行实际的注意力计算。
调度元数据:tile_scheduler_metadata 和 num_splits 用于控制 GPU 的计算调度,优化计算资源的分配。
 */
/**
 * get_mla_metadata 函数
根据输入的序列长度和硬件配置(如 GPU 架构和线程块大小),计算并返回适用于多输入注意力(MLA)的调度元数据和分块信息。
主要是为了在大规模或分布式计算场景中优化 GPU 资源的分配和计算调度。
mha_fwd_kvcache_mla 函数
实现多头注意力(MHA)的前向计算,支持分块和缓存优化,特别适用于大规模模型训练中的多输入场景。
使用了各种优化手段(如峰值归一化、分组计算、调度机制等),以提高计算效率和性能。
Pybind11 绑定
将 C++ 函数绑定到 Python,使得 Python 用户能够通过 PyTorch 扩展调用高性能的 C++ 编写的注意力计算函数。
 */
// 定义检查设备的宏
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
// 定义检查张量形状的宏
#define CHECK_SHAPE(x, ...) \
    TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// 定义检查张量是否连续的宏
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
​
// 获取多输入注意力(MLA)的元数据信息,用于调度和优化计算
std::vector<at::Tensor> get_mla_metadata(
    at::Tensor &seqlens_k,               // 输入序列长度张量(形状为 [batch_size])
    const int num_heads_per_head_k,      // 每个 Query 头对应的 Key 头的数量
    const int num_heads_k                // Key 头的总数量
) {
    // 这些常量对应 GPU 中线程块的大小(用于调度)
    static constexpr int block_size_m = 64;
    static constexpr int block_size_n = 64;
    static constexpr int fixed_overhead_num_blocks = 5;
​
    CHECK_DEVICE(seqlens_k);                     // 确保输入在 CUDA 设备上
    TORCH_CHECK(seqlens_k.is_contiguous());      // 确保输入是连续的
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); // 确保数据类型是 int32
​
    int batch_size = seqlens_k.size(0);          // 获取批处理大小
    int *seqlens_k_ptr = seqlens_k.data_ptr<int>(); // 获取序列长度的指针
​
    auto options = seqlens_k.options();          // 获取张量的选项(如数据类型和设备)
​
    // 获取当前 GPU 的属性
    auto dprops = at::cuda::getCurrentDeviceProperties();
    int sm_count = dprops->multiProcessorCount;  // 当前 GPU 的多处理器数量(SM 数量)
​
    // 计算每个 SM 分配的线程块的数量
    int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
​
    // 初始化存储调度元数据和分块数的张量
    auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
    auto num_splits = torch::empty({batch_size + 1}, options);
​
    int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    int *num_splits_ptr = num_splits.data_ptr<int>();
​
    // 确保设备和 CUDA 流一致
    at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
​
    // 设置调度参数
    Mla_metadata_params params = {};
    params.seqlens_k_ptr = seqlens_k_ptr;                // 序列长度指针
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; // 元数据指针
    params.num_splits_ptr = num_splits_ptr;              // 分块数指针
    params.batch_size = batch_size;                      // 批处理大小
    params.block_size_n = block_size_n;                  // 线程块大小(n 方向)
    params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; // 固定开销的线程块数量
    params.num_sm_parts = num_sm_parts;                  // SM 的划分数量
​
    // 调用 CUDA 函数计算元数据
    get_mla_metadata_func(params, stream);
​
    return {tile_scheduler_metadata, num_splits}; // 返回调度元数据和分块数
}
​
// 多头注意力(MHA)前向计算函数,使用分块和缓存优化(适用大模型训练中的多输入场景)
std::vector<at::Tensor> mha_fwd_kvcache_mla(
    at::Tensor &q,                               // 查询张量(形状为 [batch_size, seqlen_q, num_heads, head_size])
    const at::Tensor &kcache,                    // Key 缓存(形状为 [num_blocks, page_block_size, num_heads_k, head_size])
    c10::optional<const at::Tensor> &vcache_,    // Value 缓存(形状为 [num_blocks, page_block_size, num_heads_k, head_size_v])
    const int head_size_v,                       // Value 的头大小
    const at::Tensor &seqlens_k,                 // Key 的序列长度(形状为 [batch_size])
    const at::Tensor &block_table,               // 分块表(形状为 [batch_size, max_num_blocks_per_seq])
    const float softmax_scale,                   // Softmax 缩放因子
    bool is_causal,                              // 是否启用因果机制(避免未来时间步的干扰)
    const at::Tensor &tile_scheduler_metadata,   // 调度元数据(形状为 [num_sm_parts, TileSchedulerMetaDataSize])
    const at::Tensor &num_splits                 // 分块数(形状为 [batch_size +1])
) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0; // 检查是否是 SM90 架构的 GPU
​
    TORCH_CHECK(is_sm90); // 必须是 SM90 架构的 GPU
​
    at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; // 如果 Value 缓存存在,使用它;否则使用 Key 缓存
​
    // 检查数据类型一致性
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kBFloat16);
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
​
    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); // 确保输入在 CUDA 设备上
​
    // 确保张量是行优先存储的(即最后一个维度是连续的)
    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
​
    CHECK_DEVICE(block_table); // 确保分块表在 CUDA 设备上
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
​
    // 获取输入张量的大小
    const auto sizes = q.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q_ori = sizes[1];
    const int num_heads_ori = sizes[2];
    const int head_size = sizes[3];
​
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
​
    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
    const int page_block_size = kcache.size(1);
    const int num_heads_k = kcache.size(2);
​
    TORCH_CHECK(batch_size > 0, "batch size must be postive");
    TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
​
    if (seqlen_q_ori == 1) { is_causal = false; } // 单个时间步时不启用因果机制
​
    const int ngroups = num_heads_ori / num_heads_k;  // 分组数量
    const int seqlen_q = seqlen_q_ori * ngroups;      // 调整后的查询序列长度
    const int num_heads = num_heads_k;                // 注意力头的数量
​
    // 调整输入张量的形状
    q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size})
        .transpose(2, 3)
        .reshape({batch_size, seqlen_q, num_heads, head_size});
​
    int head_size_k = head_size;
​
    // 检查张量形状
    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
    if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
​
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    CHECK_DEVICE(seqlens_k);
    CHECK_CONTIGUOUS(seqlens_k);
    CHECK_SHAPE(seqlens_k, batch_size);
​
    at::cuda::CUDAGuard device_guard{(char)q.get_device()}; // 确保设备与输入一致
​
    // 初始化输出张量
    auto opts = q.options();
    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
​
    // 设置注意力参数
    Flash_fwd_mla_params params = {};
    params.b = batch_size;                          // 批处理大小
    params.seqlen_q = seqlen_q;                     // 查询序列长度
    params.cu_seqlens_k = seqlens_k.data_ptr<int>();// Key 的累积序列长度
    params.h = num_heads;                           // 注意力头数量
    params.h_h_k_ratio = num_heads / num_heads_k;   // 注意力头的比例
    params.ngroups = ngroups;                       // 分组数量
    params.is_causal = is_causal;                   // 是否是因果注意力
    params.d = head_size;                           // Query/Key 的头大小
    params.d_v = head_size_v;                       // Value 的头大小
    params.scale_softmax = softmax_scale;           // Softmax 的缩放因子
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); // Softmax 的缩放因子(对数形式)
​
    // 设置输入和输出张量的指针和 Stride
    params.q_ptr = q.data_ptr();                    // 查询指针
    params.k_ptr = kcache.data_ptr();               // Key 指针
    params.v_ptr = vcache.data_ptr();               // Value 指针
    params.o_ptr = out.data_ptr();                  // 输出指针
    params.softmax_lse_ptr = softmax_lse.data_ptr(); // Softmax 最终值的指针
​
    params.q_batch_stride = q.stride(0);            // 查询的批处理 Stride
    params.k_batch_stride = kcache.stride(0);       // Key 的批处理 Stride
    params.v_batch_stride = vcache.stride(0);       // Value 的批处理 Stride
    params.o_batch_stride = out.stride(0);          // 输出的批处理 Stride
​
    params.q_row_stride = q.stride(-3);             // 查询的行 Stride
    params.k_row_stride = kcache.stride(-3);        // Key 的行 Stride
    params.v_row_stride = vcache.stride(-3);        // Value 的行 Stride
    params.o_row_stride = out.stride(-3);           // 输出的行 Stride
​
    params.q_head_stride = q.stride(-2);            // 查询的头 Stride
    params.k_head_stride = kcache.stride(-2);       // Key 的头 Stride
    params.v_head_stride = vcache.stride(-2);       // Value 的头 Stride
    params.o_head_stride = out.stride(-2);          // 输出的头 Stride
​
    params.block_table = block_table.data_ptr<int>(); // 分块表的指针
    params.block_table_batch_stride = block_table.stride(0); // 分块表的批处理 Stride
    params.page_block_size = page_block_size;         // 分块的大小
​
    // 检查调度元数据和分块数的参数
    TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
    TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
    CHECK_DEVICE(tile_scheduler_metadata);
    CHECK_CONTIGUOUS(tile_scheduler_metadata);
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    params.num_sm_parts = tile_scheduler_metadata.size(0);
​
    TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
    CHECK_DEVICE(num_splits);
    CHECK_CONTIGUOUS(num_splits);
    params.num_splits_ptr = num_splits.data_ptr<int>();
​
    // 初始化累加变量
    at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
    params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
    params.oaccum_ptr = out_accum.data_ptr();
​
    // 获取当前 CUDA 流
    auto stream = at::cuda::getCurrentCUDAStream().stream();
​
    TORCH_CHECK(head_size == 576); // 确保头大小满足特定条件
​
    // 调用核心注意力计算函数(CUDA 内核)
    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
​
    // 重新调整输出张量的形状
    out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v})
        .transpose(2, 3)
        .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
​
    softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups})
        .transpose(2, 3)
        .reshape({batch_size, num_heads_ori, seqlen_q_ori});
​
    return {out, softmax_lse}; // 返回注意力输出和 Softmax 最终值
}
​
// 使用 Pybind11 将 C++ 函数绑定到 Python 扩展
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashMLA"; // 扩展的描述信息
​
    // 绑定函数到 Python
    m.def("get_mla_metadata", &get_mla_metadata);   // 绑定元数据函数
    m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); // 绑定注意力前向计算函数
}
flash_fwd_mla_bf16_sm90.cu
#include "flash_fwd_mla_kernel.h"
​
/**
 * #include "flash_fwd_mla_kernel.h"
这是一个头文件包含指令,用于引入 run_mha_fwd_splitkv_mla 函数的定义或实现。
flash_fwd_mla_kernel.h 文件可能包含函数的声明、相关数据类型和常量的定义等。
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>
这是显式实例化一个模板函数,run_mha_fwd_splitkv_mla 是一个模板函数。
cutlass::bfloat16_t 是模板的类型参数,表示使用的数据类型(bfloat16 浮点格式)。
576 是模板的非类型参数,通常表示某种尺寸或配置。
Flash_fwd_mla_params &params
这是一个参数结构体,通常包含多个成员变量,用以传递函数所需的计算参数。
它可能包含输入张量(如 Query、Key、Value)、输出张量、硬件配置(如线程块大小)以及其他计算相关的元数据。
cudaStream_t stream
这是一个 CUDA 流的标识符,用于指定在 GPU 上执行的内核函数的流。
CUDA 流允许开发者控制 GPU 内核的执行顺序和同步,不同的流中的操作可能会并行执行。
run_mha_fwd_splitkv_mla
这是一个多头注意力(MHA)的前向计算函数,具体实现可能包含以下操作:
计算注意力权重(使用 Query 和 Key)。
应用 Softmax 函数归一化权重。
将权重与 Value 相乘,得到注意力输出。
 */
​
/**
 * 该代码片段将一个基于模板的多头注意力计算函数实例化为一个具体的 CUDA 函数,用于在 GPU 上执行。
它需要预定义的参数结构体(Flash_fwd_mla_params)和 CUDA 流(stream)作为输入。
通过显式实例化模板函数,确保编译器生成特定于 bfloat16 数据类型和参数 576 的优化代码。
 */
// 显式实例化模板函数,将其绑定到特定的数据类型和参数
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(
    Flash_fwd_mla_params &params, // 传递参数,包括张量、硬件配置和其他计算信息
    cudaStream_t stream            // 指定 CUDA 流,用于控制内核的执行顺序和同步
);
flash_fwd_mla_kernel.h
#pragma once
​
#include <cute/tensor.hpp>
#include <cutlass/cutlash.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
​
using namespace cute;
​
#include "named_barrier.h"
#include "utils.h"
#include "softmax.h"
#include "static_switch.h"
#include "flash_mla.h"
/**
 * 模板结构体:Flash_fwd_kernel_traits_mla 定义了多头注意力前向计算的模板参数和数据布局,包括线程块大小、头大小、数据类型等。
   张量化矩阵乘法:make_tiled_mma 用于定义张量化的矩阵乘法操作,支持不同布局和数据类型的矩阵乘法。
   共享内存布局:SmemLayoutQ 和 SmemLayoutK 定义了共享内存中查询和键张量的布局,确保数据对齐和高效访问。
   CUDA 内核函数:flash_fwd_splitkv_mla_kernel 和 flash_fwd_splitkv_mla_combine_kernel 分别负责多头注意力的前向计算和分块结果的合并。
   因果机制:Is_causal 参数用于启用或禁用因果机制,确保在计算时不会泄露未来的序列信息。
   Softmax 归一化:flash::Softmax 用于计算 Softmax 归一化,确保注意力权重的稳定性。
   分片合并:通过共享内存和全局内存存储累积结果,并最终将结果合并到输出张量中。
 * 
 */
​
// 根据头大小选择共享内存(SMEM)的布局
template<typename PrecType, int DIM, int DIM2 = DIM>
constexpr auto getSmemLayoutK() {
    // 根据头大小对齐要求返回相应的 SMEM 布局
    constexpr int headSizeBytes = sizeof(PrecType) * DIM;
    constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
​
    if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
        return GMMA::Layout_K_SW128_Atom<PrecType>{};  // 128 字节对齐
    } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
        return GMMA::Layout_K_SW64_Atom<PrecType>{};   // 64 字节对齐
    } else {
        return GMMA::Layout_K_SW32_Atom<PrecType>{};   // 32 字节对齐
    }
}
​
// 定义多头注意力(MHA)的前向计算内核的模板结构体
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::bfloat16_t, int kHeadDimV_ = 0>
struct Flash_fwd_kernel_traits_mla {
    using Element = elem_type;                 // 数据类型(如 bfloat16)
    using ElementAccum = float;                // 累加数据类型(如 float)
    using index_t = int64_t;                   // 索引类型
​
    static constexpr int kNWarps = kNWarps_;    // 每个线程块的 warp 数量
    static constexpr int kNThreads = kNWarps * 32;  // 每个线程块的线程数量
    static constexpr int kNWarpsS = 4;          // SMEM 累加部分的 warp 数量
    static constexpr int kNThreadsS = kNWarpsS * 32; // SMEM 累加部分的线程数量
​
    static constexpr int kBlockM = kBlockM_;    // 线程块大小(行)
    static constexpr int kBlockN = kBlockN_;    // 线程块大小(列)
    static constexpr int kHeadDim = kHeadDim_;  // 查询/键的头大小
    static_assert(kHeadDim % 32 == 0);          // 确保头大小是 32 的倍数
​
    static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; // 值的头大小
    static_assert(kHeadDimV % 32 == 0);         // 确保值头大小是 32 的倍数
    static_assert(kHeadDimV <= kHeadDim);       // 值头大小不能大于查询/键的头大小
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; // SMEM 的块大小
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;      // 数据重排方式
​
    // 定义张量化矩阵乘法(Tiled MMA)操作
    using TiledMma = decltype(make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
                GMMA::Major::K, GMMA::Major::K>(),
        Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
​
    // 定义累加器的部分
    static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
    using TiledMmaO = decltype(make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
                GMMA::Major::K, GMMA::Major::MN>(),
        Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
​
    // 定义 SMEM 的布局
    using SmemLayoutQ = decltype(tile_to_shape(
        getSmemLayoutK<Element, kHeadDim>(), 
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
​
    using SmemLayoutK = decltype(tile_to_shape(
        getSmemLayoutK<Element, kHeadDim, kHeadDimV>(), 
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));
​
    using SmemLayoutV = decltype(tile_to_shape(
        getSmemLayoutK<Element, kHeadDim, kHeadDimV>(), 
        Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
​
    using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
​
    using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
    using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
​
    // 定义 SMEM 累加数据的布局
    using SmemLayoutAtomO = decltype(composition(
        Swizzle<kSwizzle, 3, 3>{}, 
        Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
    
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{}, 
        Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
​
    // 定义 SMEM 的拷贝操作
    using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
    using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
​
    // 定义张量加载和存储的配置
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
    
    static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
    static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
​
    using GmemLayoutAtom = Layout<Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                   Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopy = decltype(make_tiled_copy(
        Copy_Atom<Gmem_copy_struct, Element>{},
        GmemLayoutAtom{},
        Layout<Shape<_1, _8>>{}));  // 每次加载 8 个元素
​
    using GmemLayoutAtomO = Layout<Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                   Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopyO = decltype(make_tiled_copy(
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
        GmemLayoutAtomO{},
        Layout<Shape<_1, _8>>{}));  // 每次存储 8 个元素
​
    static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
    static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
    using GmemLayoutAtomOaccum = Layout<Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
                                        Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
    using GmemTiledCopyOaccum = decltype(make_tiled_copy(
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
        GmemLayoutAtomOaccum{},
        Layout<Shape<_1, _4>>{}));  // 每次存储 4 个累加元素
};
​
namespace flash {
​
using namespace cute;
​
// 定义共享内存(Shared Storage)的结构
template<typename Kernel_traits>
struct SharedStorageMLA {
    union {
        struct {
            // 查询、键和累积值的 SMEM 存储
            cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
            cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k;  // 双缓冲
            cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
            cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
        };
        struct {
            // 注意力权重的最大值、总和和累加结果的 SMEM 存储
            cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
            cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
            cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
        };
    };
};
​

​
// 定义在完成计算后存储结果的函数模板
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
__forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
                                      SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
    // 模板参数和数据类型的定义
    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
    constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;
​
    const int tidx = threadIdx.x;
​
    // 获取 Tiled MMA 操作的线程分片
    typename Kernel_traits::TiledMmaO tiled_mma_o;
    auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
​
    // 进行 softmax 归一化
    Tensor lse = softmax.template normalize_softmax_lse<!Split, Split>(tOrO, params.scale_softmax);
​
    // 确定输出数据的数据类型
    using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
​
    // 获取输出张量的 SMEM 布局
    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{});
​
    // 获取 SMEM 拷贝操作的线程分片
    using SmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::SmemCopyAtomO, typename Kernel_traits::SmemCopyAtomOaccum>;
    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
    auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
​
    // 对分片张量进行拷贝
    Tensor rO = flash::convert_type<ElementO>(tOrO);
    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);
    Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);
​
    __syncthreads();
    cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
​
    // 定义偏移量
    const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_oaccum = (((Split ? params.num_splits_ptr : NULL) + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
    const index_t row_offset_lseaccum = (((Split ? params.num_splits_ptr : NULL) + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
​
    // 定义输出张量
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + row_offset_o), Shape<Int<kBlockM>, Int<kHeadDimV>>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lse), Shape<Int<kBlockM>>{}, Stride<_1>{});
​
    // 获取 GMEM 拷贝操作的线程分片
    using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
    GmemTiledCopyO gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
​
    __syncthreads();
​
    if (tidx >= kNThreadsS) { return; }
​
    // 拷贝数据到输出张量
    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
    cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
​
    // 计算输出张量的布局
    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{});
    Tensor taccOcO = thr_mma_o.partition_C(caccO);
    Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
    if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
        for (int mi = 0; mi < size(taccOcO_row); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
        }
    }
​
    // 将输出张量存储到全局内存
    Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));
    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
​
    flash::copy<!Split, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
    );
}
​
// 定义计算注意力权重的函数模板
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
                                                                   const int bidb, const int bidh, const int m_block,
                                                                   const int n_split_idx, const int seqlen_k,
                                                                   const int n_block_min, const int n_block_max, const bool NoSplit,
                                                                   SharedStorage &shared_storage) {
    // 模板参数和数据类型的定义
    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
    constexpr int kNThreads = Kernel_traits::kNThreads;
    constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
    static_assert(kNThreads == 256 && kNThreadsS == 128);
    using Element = typename Kernel_traits::Element;
    using index_t = typename Kernel_traits::index_t;
​
    const int tidx = threadIdx.x;
    int n_block = n_block_max - 1;
​
    // 定义 SMEM 张量的布局
    Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
    Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
​
    Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
    Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
    Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
    Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
    Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
    Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
    Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
    Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
​
    // 定义 Tiled MMA 操作
    typename Kernel_traits::TiledMmaO tiled_mma_o;
    auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
    Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt);
    Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
    clear(tOrO);
​
    // 定义 Softmax 对象
    flash::Softmax<2 * size<1>(tOrO)> softmax;
​
    int warp_group_idx = cutlass::canonical_warp_group_idx();
    if (warp_group_idx == 0) {
        // 获取 Tiled MMA 操作的线程分片
        typename Kernel_traits::TiledMma tiled_mma;
        auto thr_mma = tiled_mma.get_thread_slice(tidx);
        Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
        Tensor tSrK = thr_mma.partition_fragment_B(sK);
​
        if (n_block % 2 == 1) {
            // 对 SMEM 进行双缓冲
            constexpr int sK_offset = size(sK);
            tSrK.data() = tSrK.data() + sK_offset / 8;
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }
​
        // 定义掩码步数
        constexpr int n_masking_steps = !Is_causal ? 1 : static_cast<int>(ceil_div(kBlockM, kBlockN)) + 1;
​
        // 循环处理掩码步
        for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
            __syncthreads();
​
            // 执行矩阵乘法
            Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
            flash::gemm<! true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
​
            bool is_masking_step = masking_step > 0;
            bool is_first_masking_step = masking_step == n_masking_steps;
​
            if (is_masking_step) {
                Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
                Tensor tScS = thr_mma.partition_C(cS);
​
                // 根据条件设置掩码
                if constexpr (!Is_causal) {
                    for (int i = 0; i < size(tSrS); ++i) {
                        if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) {
                            tSrS(i) = -INFINITY;
                        }
                    }
                } else {
                    for (int i = 0; i < size(tSrS); ++i) {
                        int row = int(get<0>(tScS(i)));
                        int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
                        if (int(get<1>(tScS(i))) > col_limit_right) {
                            tSrS(i) = -INFINITY;
                        }
                    }
                }
            }
​
            // 计算 Softmax 和缩放因子
            Tensor scale_o = is_first_masking_step
                ? softmax.template softmax<! true, Is_causal>(tSrS, params.scale_softmax_log2)
                : is_masking_step ? softmax.template softmax<! false, Is_causal>(tSrS, params.scale_softmax_log2)
                                  : softmax.template softmax<! false, false>(tSrS, params.scale_softmax_log2);
​
            // 将结果存储到 SMEM
            Tensor rP = flash::convert_type<Element>(tSrS);
            cute::copy(rP, tPsP);
            cute::copy(scale_o, tScale_osScale_o);
​
            cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
​
            // 更新输出张量
            flash::rescale_o(tOrO, scale_o);
​
            // 执行矩阵乘法,计算输出
            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
            flash::gemm<! false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
​
            // 更新 SMEM 的数据指针
            const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
            tSrK.data() = tSrK.data() + sK_offset / 8;
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }
​
        // 将 Softmax 的最大值和总和存储到 SMEM
        cute::copy(softmax.row_max, tRow_maxsRow_max);
        cute::copy(softmax.row_sum, tRow_sumsRow_sum);
        cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
    } else {
        // 获取分块表的指针
        const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
        int cur_block_table = __ldg(&block_table[n_block]);
​
        // 定义 GMEM 张量的布局
        const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
        Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q), Shape<Int<kBlockM>, Int<kHeadDim>>{}, make_stride(params.q_row_stride, _1{}));
​
        typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
        auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
        Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
        Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
        Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));
        Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
​
        // 将查询数据加载到 SMEM
        flash::copy<! true, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, params.seqlen_q - m_block * kBlockM);
​
        // 定义 GMEM 张量的布局
        const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
        Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k), Shape<Int<kBlockN>, Int<kHeadDim>>{}, make_stride(params.k_row_stride, _1{}));
​
        typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
        auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
        Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
        Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
        Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));
        Tensor tKcK = gmem_thr_copy_K.partition_S(cK);
        Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
​
        // 对 SMEM 进行双缓冲
        if (n_block % 2 == 1) {
            constexpr int sK_offset = size(sK);
            tKsK.data() = tKsK.data() + sK_offset;
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }
​
        // 将键数据加载到 SMEM
        const index_t offset_k = cur_block_table * params.k_batch_stride;
        tKgK.data() = tKgK.data() + offset_k;
        flash::copy<! true, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, seqlen_k - n_block * kBlockN);
        tKgK.data() = tKgK.data() - offset_k;
        cute::cp_async_fence();
​
        // 循环处理分块
        for (; n_block >= n_block_min; --n_block) {
            flash::cp_async_wait<0>();
            __syncthreads();
​
            if (n_block - 1 >= n_block_min) {
                // 对 SMEM 进行双缓冲
                const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
                tKsK.data() = tKsK.data() + sK_offset;
​
                // 加载下一个分块的数据
                cur_block_table = __ldg(&block_table[n_block - 1]);
                const index_t offset_k = cur_block_table * params.k_batch_stride;
                tKgK.data() = tKgK.data() + offset_k;
                flash::copy<! true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
                tKgK.data() = tKgK.data() - offset_k;
                cute::cp_async_fence();
            }
​
            // 同步线程
            cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
​
            // 获取更新后的分块数据
            if (n_block - 2 >= n_block_min) {
                cur_block_table = __ldg(&block_table[n_block - 2]);
            }
​
            // 获取 Tiled MMA 操作的线程分片
            typename Kernel_traits::TiledMma tiled_mma;
            auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
            Tensor rP = make_tensor<Element>(tSrS_layout.layout());
            Tensor scale_o = make_tensor<float>(Shape<_2>{});
            cute::copy(tScale_osScale_o, scale_o);
            cute::copy(tPsP, rP);
​
            // 更新输出张量
            flash::rescale_o(tOrO, scale_o);
            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
            flash::gemm<! false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
​
            // 更新 SMEM 的数据指针
            const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }
​
        // 同步线程
        cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
        cute::copy(tRow_maxsRow_max, softmax.row_max);
        cute::copy(tRow_sumsRow_sum, softmax.row_sum);
    }
​
    // 根据条件存储结果
    if (NoSplit) {
        store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
    } else {
        store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
    }
}
​
// 定义 CUDA 内核函数模板,负责多头注意力的前向计算
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
    // 模板参数的定义
    constexpr int kBlockN = Kernel_traits::kBlockN;
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;
​
    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
​
    // 获取调度元数据
    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
    int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
    int begin_idx = tile_scheduler_metadata.x;
    int begin_seqlen = tile_scheduler_metadata.y;
    int end_idx = tile_scheduler_metadata.z;
    int end_seqlen = tile_scheduler_metadata.w;
    if (begin_idx >= params.b) return;
    int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
​
    // 循环处理批量数据
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
        const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
        const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
        const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
        const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
        const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
​
        if (batch_id > begin_idx) {
            __syncthreads(); // 同步线程
        }
​
        // 调用计算注意力权重的函数模板
        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
    }
}
​
// 定义将分块结果合并的 CUDA 内核函数模板
template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
__global__ void __launch_bounds__(256, 1, 1)
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
    // 线程和块的配置
    constexpr int kNThreads = 128;
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;
    const int hs = params.h * params.seqlen_q;
    const int batch_idx = bidx / hs;
    const int hs_idx = bidx % hs;
​
    // 获取分块信息
    const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
    const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
    FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
    if (actual_num_splits == 1) return;
​
    // 定义共享内存数组
    __shared__ ElementAccum sLseScale[kMaxSplits];
​
    // 定义全局存储的 LSE 累积张量
    const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum), Shape<Int<kMaxSplits>>{}, make_stride(hs));
​
    // 定义全局存储的 LSE 张量
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + bidx), Shape<_1>{}, Stride<_1>{});
​
    int warp_idx = cutlass::canonical_warp_idx_sync();
    if (warp_idx == 0) {
        // 每个线程加载多个 LSE 值
        constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
        float local_lse[kNLsePerThread];
​
        for (int i = 0; i < kNLsePerThread; ++i) {
            const int split = i * 32 + tidx;
            local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
        }
​
        // 在线程内计算最大值
        float max_lse = -INFINITY;
        for (int i = 0; i < kNLsePerThread; ++i) {
            max_lse = max(max_lse, local_lse[i]);
        }
​
        // 在 warp 内同步计算全局最大值
        for (int offset = 16; offset >= 1; offset /= 2) {
            max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
        }
​
        // 处理无效情况
        max_lse = max_lse == -INFINITY ? 0.0f : max_lse;
​
        // 计算总和
        float sum_lse = 0;
        for (int i = 0; i < kNLsePerThread; ++i) {
            sum_lse += expf(local_lse[i] - max_lse);
        }
​
        // 在 warp 内同步计算全局总和
        for (int offset = 16; offset >= 1; offset /= 2) {
            sum_lse += __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
        }
​
        // 计算全局 LSE
        float global_lse = sum_lse == 0.f || isnan(sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
        if (tidx == 0) {
            gLSE(0) = global_lse;
        }
​
        // 将缩放因子存储到共享内存
        for (int i = 0; i < kNLsePerThread; ++i) {
            const int split = i * 32 + tidx;
            if (split < actual_num_splits) {
                sLseScale[split] = expf(local_lse[i] - global_lse);
            }
        }
    }
    __syncthreads();
​
    // 计算输出张量的布局
    static_assert(kHeadDimV % kNThreads == 0);
    constexpr int Elements = kHeadDimV / kNThreads;
    const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
​
    // 定义 GMEM 张量的布局
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum), Shape<Int<kHeadDimV>>{}, Stride<_1>{});
​
    // 定义 GMEM 拷贝操作的线程分片
    using GmemTiledCopyOaccum = decltype(make_tiled_copy(
        Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
        Layout<Shape<Int<kNThreads>>>{},
        Layout<Shape<Int<Elements>>>{}));
    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
​
    // 定义张量
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
    Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
    clear(tOrO);
​
    // 循环处理每个分块
    for (int split = 0; split < actual_num_splits; ++split) {
        cute::copy(tOgOaccum, tOrOaccum);
        ElementAccum lse_scale = sLseScale[split];
        for (int i = 0; i < size(tOrO); ++i) {
            tOrO(i) += lse_scale * tOrOaccum(i);
        }
        tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
    }
​
    // 将结果存储到全局内存
    Tensor rO = flash::convert_type<Element>(tOrO);
    const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
    const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
    auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<Elements>>{}, make_stride(_1()));
    cute::copy(rO, gO);
}
​
} // namespace flash
​
// 定义运行 Flash 分块前向计算的功能函数
template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
    // 参数校验
    FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
​
    // 根据因果机制选择内核函数
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        // 获取 CUDA 内核函数
        auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
​
        // 设置内核函数的共享内存大小
        constexpr size_t smem_size = sizeof(SharedStorage);
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
​
        // 调用 CUDA 内核函数
        kernel<<<dim3(ceil_div(params.seqlen_q, Kernel_traits::kBlockM), params.h, params.num_sm_parts),
                 Kernel_traits::kNThreads, smem_size, stream>>>(params);
    });
​
    // 检查 CUDA 内核函数的执行结果
    CHECK_CUDA_KERNEL_LAUNCH();
​
    // 调用 CUDA 内核函数,合并分块结果
    dim3 grid_combine(params.b * params.h * params.seqlen_q);
    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
        auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
            typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
        combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
    });
​
    // 检查 CUDA 内核函数的执行结果
    CHECK_CUDA_KERNEL_LAUNCH();
}
​
// 定义运行 MHA 的分块前向计算的功能函数
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
    static_assert(Headdim == 576); // 确保头大小为 576
    FLASH_ASSERT(params.d_v == 512); // 确保值头大小为 512
    FLASH_ASSERT(params.k_ptr == params.v_ptr); // 确保键和值共享同一张量
​
    using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; // 定义模板结构体
    run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
}
​
// 定义生成 MLA 调度元数据的功能函数
static constexpr int MaxBatchSize = 4096;
​
__global__ void __launch_bounds__(256, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
    int *seqlens_k_ptr = params.seqlens_k_ptr;
    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
    int *num_splits_ptr = params.num_splits_ptr;
    int batch_size = params.batch_size;
    int block_size_n = params.block_size_n;
    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
    int num_sm_parts = params.num_sm_parts;
​
    // 定义共享内存数组
    __shared__ int num_blocks_shared[MaxBatchSize];
    __shared__ int num_splits_shared[MaxBatchSize];
​
    // 计算总数
    int total_num_blocks = 0;
    for (int i = threadIdx.x; i < batch_size; i += 32) {
        int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
        total_num_blocks += num_blocks + fixed_overhead_num_blocks;
        num_blocks_shared[i] = num_blocks;
    }
​
    // 使用同步指令进行线程间的数据聚合
    for (int offset = 16; offset >= 1; offset /= 2) {
        total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
    }
    __syncwarp();
​
    // 计算调度元数据
    if (threadIdx.x == 0) {
        int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
​
        int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
        num_splits_shared[0] = 0;
​
        for (int i = 0; i < num_sm_parts; ++i) {
            int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
​
            tile_scheduler_metadata0[0] = now_idx;
            tile_scheduler_metadata0[1] = now_block * block_size_n;
            tile_scheduler_metadata1 = now_n_split_idx;
​
            int remain_payload = payload;
​
            while (now_idx < batch_size) {
                int num_blocks = num_blocks_shared[now_idx];
                int now_remain_blocks = num_blocks - now_block;
​
                if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
                    cum_num_splits += now_n_split_idx + 1;
                    num_splits_shared[now_idx + 1] = cum_num_splits;
                    remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
                    ++now_idx;
                    now_block = 0;
                    now_n_split_idx = 0;
                } else {
                    if (remain_payload - fixed_overhead_num_blocks > 0) {
                        now_block += remain_payload - fixed_overhead_num_blocks;
                        ++now_n_split_idx;
                        remain_payload = 0;
                    }
                    break;
                }
            }
​
            tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
            tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
            *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
            tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
​
            FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
        }
    }
    __syncwarp();
​
    // 将分块数量存储到全局内存
    for (int i = threadIdx.x; i <= batch_size; i += 32) {
        num_splits_ptr[i] = num_splits_shared[i];
    }
}
​
// 定义生成 MLA 调度元数据的功能函数
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
    FLASH_ASSERT(params.batch_size < MaxBatchSize); // 确保批量大小在最大值范围内
​
    // 调用 CUDA 内核函数
    get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
    CHECK_CUDA_KERNEL_LAUNCH();
}
flash_mla.h
#pragma once


/**
 * 参数结构体:Flash_fwd_mla_params 和 Mla_metadata_params 是通过指针和索引传递参数的关键结构体,用于在 GPU 上进行高效的内存访问和计算调度。
   模板函数:run_mha_fwd_splitkv_mla 是一个模板函数,允许针对不同的数据类型和头大小进行优化。
   调度元数据:TileSchedulerMetaDataSize 定义了调度元数据的大小,用于控制 GPU 的计算调度。
 */

 /**  主要功能: 在 GPU 上实现高效的多头注意力计算
  * Flash_fwd_mla_params 结构体定义了多头注意力(MHA)前向计算所需的参数,包括:输入张量(查询、键、值)和输出张量的指针。张量的维度信息(如批量大小、序列长度、头大小)。
计算相关的超参数(如 Softmax 缩放因子、因果机制标志)。调度和分块信息(如分块表、元数据指针)。
run_mha_fwd_splitkv_mla 函数模板调用多头注意力的前向计算功能,支持分块和缓存优化。通过模板参数 T 和 Headdim,可以针对不同的数据类型和头大小进行优化。
Mla_metadata_params 结构体定义了生成多线程计算元数据所需的参数,包括:序列长度数组分块表和分块数指针。SM 的划分数量和固定开销的块数量。
get_mla_metadata_func 函数生成多线程计算的元数据,用于优化 GPU 上的调度和计算。
  */
// 定义多头注意力(MHA)前向计算的参数结构体
struct Flash_fwd_mla_params {
    using index_t = int64_t; // 索引类型

    // 参数描述:
    // b: 批量大小
    // seqlen_q: 查询序列长度
    // d: 查询/键的头大小
    // d_v: 值的头大小
    // h: 注意力头数量
    // h_h_k_ratio: 注意力头的比例
    // ngroups: 分组数量
    // is_causal: 是否启用因果机制
    // scale_softmax: Softmax 缩放因子
    // scale_softmax_log2: Softmax 缩放因子(以 2 为底的对数形式)
    // cu_seqlens_k: 累积序列长度数组(指向 GPU 内存)

    int b, seqlen_q, d, d_v;
    int h, h_h_k_ratio, ngroups;
    bool is_causal;
    float scale_softmax, scale_softmax_log2;
    int *__restrict__ cu_seqlens_k;

    // 张量指针(指向 GPU 内存)
    void *__restrict__ q_ptr;       // 查询张量
    void *__restrict__ k_ptr;       // 键张量
    void *__restrict__ v_ptr;       // 值张量
    void *__restrict__ o_ptr;       // 输出张量
    void *__restrict__ softmax_lse_ptr; // Softmax 最终值张量

    // 张量的 Stride(布局)信息
    index_t q_batch_stride;         // 查询张量的批量 Stride
    index_t k_batch_stride;         // 键张量的批量 Stride
    index_t v_batch_stride;         // 值张量的批量 Stride
    index_t o_batch_stride;         // 输出张量的批量 Stride
    index_t q_row_stride;           // 查询张量的行 Stride
    index_t k_row_stride;           // 键张量的行 Stride
    index_t v_row_stride;           // 值张量的行 Stride
    index_t o_row_stride;           // 输出张量的行 Stride
    index_t q_head_stride;          // 查询张量的头 Stride
    index_t k_head_stride;          // 键张量的头 Stride
    index_t v_head_stride;          // 值张量的头 Stride
    index_t o_head_stride;          // 输出张量的头 Stride

    // 分块表相关参数
    int *__restrict__ block_table;      // 分块表指针
    index_t block_table_batch_stride;   // 分块表的批量 Stride
    int page_block_size;                // 分块大小

    // 调度元数据相关参数
    int *__restrict__ tile_scheduler_metadata_ptr; // 调度元数据指针
    int num_sm_parts;                              // SM 的划分数量
    int *__restrict__ num_splits_ptr;              // 分块数指针

    // 累加值相关指针
    void *__restrict__ softmax_lseaccum_ptr;       // Softmax 最终值累加指针
    void *__restrict__ oaccum_ptr;                 // 输出累加指针
};

// 调度元数据的大小(单位:字节)
static constexpr int TileSchedulerMetaDataSize = 8;
// 元数据的格式为:[begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]



// 定义运行多头注意力(MHA)的前向计算功能函数模板
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);

// 定义生成多线程计算元数据的结构体
struct Mla_metadata_params {
    // 序列长度数组(指向 GPU 内存)
    int *__restrict__ seqlens_k_ptr;
    // 调度元数据指针
    int *__restrict__ tile_scheduler_metadata_ptr;
    // 分块数指针
    int *__restrict__ num_splits_ptr;
    // 批量大小
    int batch_size;
    // 块大小(N 方向)
    int block_size_n;
    // 固定开销的块数量
    int fixed_overhead_num_blocks;
    // SM 的划分数量
    int num_sm_parts;
};

// 定义生成多线程计算元数据的功能函数
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);
 
named_barrier.h
#pragma once

#include "cutlass/barrier.h" // 引入 Cutlass 库中的屏障(Barrier)功能
/**
 * 主要功能:多线程计算中提供一种同步机制,确保不同阶段的计算按顺序执行,避免数据竞争和冲突。
 */

/**
 * cutlass/barrier.h:引入 Cutlass 库中的屏障功能,用于多线程同步。
   命名屏障:通过枚举类定义命名屏障,便于在代码中明确标识不同阶段的同步点。
 */

 /**
  * 命名屏障(NamedBarriers)定义了两个命名屏障,用于在多线程计算中同步不同阶段的执行。
   SReady:表示某个阶段(如数据加载或计算)已准备就绪。
   SoftmaxReady:表示 Softmax 计算已准备就绪。
  避免冲突
    通过为屏障命名,避免了不同计算阶段之间的潜在冲突,确保多线程计算的正确性和高效性。
  */
namespace flash {



// 定义命名屏障的枚举类,用于避免潜在的冲突
enum class NamedBarriers {
    SReady = 1,          // 表示某个阶段已准备就绪的屏障
    SoftmaxReady = 2,    // 表示 Softmax 计算已准备就绪的屏障
};

} // namespace flash
softmax.h
#pragma once

#include <cmath>

#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

#include "utils.h"

/**
 * 主要功能:为 Flash-Attention 提供高效的计算和归一化操作,针对大规模模型训练进行了优化。
 */

/**
 * 线程内/多线程 reduce 操作:
    thread_reduce_ 和 quad_allreduce_ 分别实现了线程内和多线程间的 reduce 操作,用于将张量按行聚合为一维摘要张量。
    reduce_ 综合了线程内 reduce 和多线程 quad 全规约操作,用于高效的 GPU 并行计算。
张量变换与归一化:
    scale_apply_exp2 定义了对张量应用指数量化,通过指数量变换实现精确的归一化。
    max_scale_exp2_sum 综合了求最大值、指数变换和求和操作,用于计算 Softmax 归一化的中间结果。
Softmax 归一化:
    rescale_o 定义了对输出张量进行归一化的函数,确保归一化因子被正确应用。
    Softmax 模板类实现了对输入张量进行 Softmax 归一化的功能,包括求最大值、指数变换、求和和归一化因子计算等步骤。
误差与优化:
    代码中对 NaN 和无效值进行了处理,避免了计算过程中的错误。
    使用了 Allreduce 和 UNFUSE_FMA 等优化策略,提高了计算性
 */

namespace flash {

using namespace cute;



// 定义线程内 reduce 操作,将张量按行聚合为一维摘要张量
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor"); // 确保输入张量是二维的
    static_assert(Layout1::rank == 1, "Only support 1D Tensor"); // 确保摘要张量是一维的
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));  // 确保摘要张量的行数与输入张量的行数一致

    // 按行对张量进行聚合
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}

// 定义多线程间的 quad 全规约操作,确保不同线程中的数据一致
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));  // 确保输入和输出张量的大小一致

    for (int i = 0; i < size(dst); i++){
        dst(i) = Allreduce<4>::run(src(i), op);     // 使用 Allreduce 操作进行数据规约
    }
}

// 定义综合 reduce 操作,结合线程内 reduce 和多线程 quad 全规约
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    thread_reduce_<zero_init>(tensor, summary, op); // 线程内 reduce 操作
    quad_allreduce_(summary, summary, op);          // 线程间 quad 全规约操作
}

// 定义求最大值的 reduce 操作
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
    MaxOp<float> max_op;                            // 定义最大值操作
    reduce_<zero_init>(tensor, max, max_op);        // 使用 reduce_ 操作求最大值
}

// 定义求和的 reduce 操作
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
    SumOp<float> sum_op;                            // 定义求和操作
    thread_reduce_<zero_init>(tensor, sum, sum_op); // 使用线程内 reduce 操作求和
}

// 定义对张量应用指数量化
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor"); // 确保输入张量是二维的
    static_assert(Layout1::rank == 1, "Only support 1D Tensor"); // 确保最大值张量是一维的
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));       // 确保最大值张量的行数与输入张量的行数一致

    // 按行对张量进行量级调整和指数变换
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
        for (int ni = 0; ni < size<1>(tensor); ++ni)  {
            #ifdef UNFUSE_FMA
                tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
            #else
                tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            #endif
        }
    }
    return tensor;
}

// 定义综合 max-scale-exp-sum 操作,包括求最大值、指数变换和求和
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor"); // 确保输入张量是二维的
    static_assert(Layout1::rank == 1, "Only support 1D Tensor"); // 确保最大值和求和张量是一维的
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));       // 确保最大值张量的行数与输入张量的行数一致

    // 按行对张量进行 max-scale-exp-sum 操作
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        MaxOp<float> max_op;                            // 定义最大值操作
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);   // 使用 Allreduce 操作进行数据规约
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
        sum(mi) = 0;
        for (int ni = 0; ni < size<1>(tensor); ++ni)  {
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); // 应用指数变换
            sum(mi) += tensor(mi, ni);                                  // 求和
        }
        SumOp<float> sum_op;                                            // 定义求和操作
        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);                   // 使用 Allreduce 操作进行数据规约
    }
}

// 定义对输出张量进行归一化的函数
template<typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
    // 转换 acc_o 的布局为行优先
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    for (int mi = 0; mi < size(scale_o); ++mi) {
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { 
            acc_o_rowcol(mi, ni) *= scale_o(mi); // 应用归一化因子
        }
    }
}



// 定义用于计算 Softmax 的模板类
template <int kNRows>
struct Softmax {
    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{})); // 定义摘要张量类型
    TensorT row_max, row_sum;                                           // 行最大值和行和

    __forceinline__ __device__ Softmax() {}                             // 默认构造函数

    // 定义对输入张量进行 Softmax 操作并返回缩放因子
    template<bool Is_first, bool Check_inf=false, typename Tensor0>
    __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
        // 转换张量布局为行优先
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);      // 确保输入张量符合布局要求

        TensorT scale_o;                                                // 初始化缩放因子
        clear(scale_o);

        if (Is_first) {
            // 如果是第一次迭代,计算行最大值
            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
            // 应用指数变换
            flash::scale_apply_exp2<scales_max>::apply(scores, row_max, softmax_scale_log2);
            // 计算行和
            flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
        } else {
            // 否则,使用缓存的行最大值继续计算
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
            #pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = !Check_inf ? row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                scale_o(mi) = scores_scale;
                row_sum(mi) *= scores_scale;
            }
            flash::scale_apply_exp2<scales_max>::apply(scores, row_max, softmax_scale_log2);
            flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
        }
        return scale_o;
    }

    // 定义对输入张量进行 Softmax 归一化并返回最终值
    template<bool Is_dropout=false, bool Split=false, typename Tensor0>
    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
        SumOp<float> sum_op;                                            // 定义求和操作
        quad_allreduce_(row_sum, row_sum, sum_op);                      // 使用多线程 Quad Allreduce 操作进行数据规约

        TensorT lse = make_fragment_like(row_sum);                      // 初始化累加最终值
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); // 确保输入张量符合布局要求

        // 根据行和计算累加最终值
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { 
                acc_o_rowcol(mi, ni) *= scale; // 应用归一化因子
            }
        }
        return lse;
    }
};

}  // namespace flash
static_switch.h
#pragma once
/**
 * 主要功能:为 Flash-Attention 提供基本的工具宏,用于错误检查、断言和条件编译。
 */

/**
 * 错误检查
  CHECK_CUDA 和 CHECK_CUDA_KERNEL_LAUNCH:这两个宏用于检查 CUDA 调用是否成功,确保 GPU 操作的安全性和正确性。CHECK_CUDA 用于显式检查 CUDA 函数调用的结果,而 CHECK_CUDA_KERNEL_LAUNCH 则在 CUDA 内核启动后检查是否有错误发生。
断言
  FLASH_ASSERT 和 FLASH_DEVICE_ASSERT:这两个宏用于在代码中插入断言,确保程序状态符合预期。FLASH_ASSERT 用于主机代码,FLASH_DEVICE_ASSERT 用于 GPU 设备代码,当断言失败时,程序会终止并打印错误信息。
条件编译
  BOOL_SWITCH 和 MLA_NUM_SPLITS_SWITCH:这两个宏用于在编译时或运行时根据条件动态选择代码路径。BOOL_SWITCH 根据布尔条件选择不同的实现,而 MLA_NUM_SPLITS_SWITCH 根据分块数量选择不同的实现,从而优化 GPU 的计算性能。
 */
// 定义检查 CUDA 调用是否成功的宏
#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;     // 调用 CUDA 函数或操作
        if (status_ != cudaSuccess) {   // 如果调用失败
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); // 打印错误信息
            exit(1);                      // 退出程序
        }                                                                                                 \
    } while(0) // 确保宏的行为类似于一条语句

// 定义检查 CUDA 内核启动是否成功的宏
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) // 检查上一次 CUDA 内核调用是否有错误

// 定义运行时断言宏,用于在 CPU 代码中检查条件是否成立
#define FLASH_ASSERT(cond)                                                                                \
    do {                                                                                                  \
        if (not (cond)) {                                                                                 \
            fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond);                 // 打印错误信息
            exit(1);                                                                                      // 退出程序
        }                                                                                                 \
    } while(0)

// 定义设备端断言宏,用于在 GPU 代码中检查条件是否成立
#define FLASH_DEVICE_ASSERT(cond)                                                                         \
    do {                                                                                                  \
        if (not (cond)) {                                                                                 \
            printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond);                          // 打印错误信息
            asm("trap;");                                                                                 // 触发陷阱指令,终止程序
        }                                                                                                 \
    } while(0)

// 定义布尔条件编译宏,用于根据布尔条件选择不同的代码路径
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      constexpr static bool CONST_NAME = true;  // 如果条件成立,设置 CONST_NAME 为 true
      return __VA_ARGS__();                     // 执行相应的代码
    } else {                                    \
      constexpr static bool CONST_NAME = false; // 如果条件不成立,设置 CONST_NAME 为 false
      return __VA_ARGS__();                     // 执行相应的代码
    }                                           \
  }()

// 定义根据分块数量选择不同实现的宏,用于动态选择最优的计算策略
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
  [&] {                                              \
    if (NUM_SPLITS <= 32) {                          \
      constexpr static int NAME = 32;                // 分块数量小于等于 32 时,选择对应实现
      return __VA_ARGS__();                          \
    } else if (NUM_SPLITS <= 64) {                   \
      constexpr static int NAME = 64;                // 分块数量在 33-64 之间时,选择对应实现
      return __VA_ARGS__();                          \
    } else if (NUM_SPLITS <= 96) {                   \
      constexpr static int NAME = 96;                // 分块数量在 65-96 之间时,选择对应实现
      return __VA_ARGS__();                          \
    } else if (NUM_SPLITS <= 128) {                  \
      constexpr static int NAME = 128;               // 分块数量在 97-128 之间时,选择对应实现
      return __VA_ARGS__();                          \
    } else if (NUM_SPLITS <= 160) {                  \
      constexpr static int NAME = 160;               // 分块数量在 129-160 之间时,选择对应实现
      return __VA_ARGS__();                          \
    } else {                                         \
      FLASH_ASSERT(false);                           // 分块数量超过 160 时,触发断言错误
    }                                                \
  }()
utils.h
// 从 https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h 适配而来

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_bf16.h>  // 包含 NVIDIA CUDA 的半精度浮点数支持

#include <cute/tensor.hpp>  // 包含 CUTLASS 的张量操作库
#include <cutlass/array.h>  // 包含 CUTLASS 的数组操作
#include <cutlass/cutlass.h> // 包含 CUTLASS 的基础功能
#include <cutlass/numeric_conversion.h> // 包含 CUTLASS 的数值转换功能
#include <cutlass/numeric_types.h>     // 包含 CUTLASS 的数值类型定义

/**
 * 加速和优化深度学习模型中的某些底层操作,尤其是矩阵运算和数据处理步骤,能够显著提高 GPU 上模型训练和推理的性能。
 */

/**
 * 矩阵运算优化:
    提供了自定义的规约操作(Allreduce),能够高效地在 Warp 内进行数据的规约(如最大值或求和操作),可用于加速矩阵运算中的某些步骤。
    实现了一个高效的矩阵乘法(gemm)函数,通过手动展开循环和控制累加器的初始化,能够更好地利用 GPU 的特性和 CUTLASS 的功能,提高矩阵乘法的性能。
张量布局转换:
    提供了两组函数(convert_layout_acc_rowcol 和 convert_layout_acc_Aregs)用于将张量的布局从一种形式转换为另一种形式,以适应不同架构(如 SM80 和 SM90)的要求,确保数据在 GPU 上能够更高效地存储和访问。
异步内存操作优化:
    提供了一个异步内存操作等待函数(cp_async_wait),能够精确地控制异步内存操作的等待时间,减少不必要的等待开销,提高 GPU 的利用率。
数据类型转换:
    提供了一个模板函数(convert_type),能够将张量的数据类型进行转换,支持不同的数值类型(如 float、half 等)之间的转换,增强代码的通用性。
张量拷贝控制:
    提供了一个张量拷贝函数(copy),允许开发者通过布尔参数精确控制拷贝操作的行为,如是否清除超出边界的元素、是否跳过某些依赖项等,提高拷贝效率和灵活性。
最大值和加法运算:
    提供了自定义的最大值和加法运算模板(MaxOp 和 SumOp),能够处理各种数据类型(包括 float),并针对特定类型进行了优化,提高运算速度。
 */



namespace flash { // 定义 flash 命名空间



template<typename T>
struct MaxOp { // 定义一个通用的最大值运算模板
    __device__ __forceinline__ T operator()(T const & x, T const & y) { // 设备端函数,实现二元运算
        return x > y ? x : y; // 返回两个值中的最大值
    }
};

// 为 float 类型特化 MaxOp,使用硬件优化的 max 函数
template <>
struct MaxOp<float> {
    __device__ __forceinline__ float operator()(float const &x, float const &y) {
        return max(x, y); // 更高效地使用硬件 max 指令
    }
};



template<typename T>
struct SumOp { // 定义一个通用的加法运算模板
    __device__ __forceinline__ T operator()(T const & x, T const & y) { // 设备端函数,实现加法
        return x + y; // 返回两个值的和
    }
};



template<int THREADS>
struct Allreduce { // 继续进这个结构体是用于 Warp 级规约的
// Allreduce 模板结构体,用于跨线程块的 reducing 操作
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4, // 线程数必须是 32/16/8/4
                  "THREADS must be 32, 16, 8, or 4"); 
    template<typename T, typename Operator> // T 是数据类型,Operator 是运算类型
    static __device__ __forceinline__ T run(T x, Operator &op) { // 运行规约操作
        constexpr int OFFSET = THREADS / 2; // 定义线程块的偏移量
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); // 使用 shuffle 操作进行规约
        return Allreduce<OFFSET>::run(x, op); // 递归调用,不断缩小线程块大小
    }
};

template<> // 特化 Allreduce<2>,因为递归终止条件是线程块大小为 2
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); // 偏移量为 1,进行最后一次规约
        return x;
    }
};



// 定义一个通用的 GEMM 函数,用于矩阵乘法
template<bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, 
         typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, 
                                      Tensor1 const &tCrB, Tensor2 &tCrC) {

    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, 
                                            typename TiledMma::FrgTypeA>::value;

    // 需要移除 tCrA 的 const 属性,因为 warpgroup_fence_operand 不支持 const 张量
    if constexpr (Is_RS) { 
        cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); 
    }

    // 对输出张量应用 warpgroup_fence_operand
    warpgroup_fence_operand(tCrC);

    // 根据是否需要同步线程组,执行到来操作
    if constexpr (arrive) { warpgroup_arrive(); }

    // 如果需要初始化为零,设置累加器为 Zero,并手动展开 K 维度的循环
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; // 初始化累加器为零
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); // 执行 GEMM
            tiled_mma.accumulate_ = GMMA::ScaleOut::One; // 累加操作设置为 One
        }
    } else {
        // 默认情况下,手动展开 K 维度的循环
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }

    // 根据是否需要提交,执行提交操作
    if constexpr (commit) {
        warpgroup_commit_batch();
    }

    // 等待指定数量的线程组
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }

    // 对输出张量应用 query 逻辑
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}



// 对于不同架构(SM80 和 SM90),将 acc_layout 转换为行/列布局
// 转换规则:将 3D 布局转换为二维布局
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {

    // 判断是否为 SM90 架构
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = acc_layout;

        if constexpr (!Transposed) {
            return make_layout(
                make_layout(get<0, 1>(l), get<1>(l)), 
                make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))
            );
        } else {
            return make_layout(
                make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), 
                make_layout(get<0, 1>(l), get<1>(l))
            );
        }
    } else {
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});
        if constexpr (!Transposed) {
            return make_layout(
                make_layout(get<0, 1>(l), get<1>(l)), 
                make_layout(get<0, 0>(l), get<2>(l))
            );
        } else {
            return make_layout(
                make_layout(get<0, 0>(l), get<2>(l)), 
                make_layout(get<0, 1>(l), get<1>(l))
            );
        }
    }
}



// 转换光照布局为特定的 Aregs 布局
// 对于不同架构(SM80 和 SM90)和数据类型(FP16/BF16 和 FP8),进行不同的转换
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
    using X = Underscore;

    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {
        // SM90 架构的转换逻辑
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});
            return make_layout(
                make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)),
                get<1>(acc_layout),
                coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))
            );
        } else {
            static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
            static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
            static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
            auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>> {});
            return make_layout(
                make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
                get<1>(acc_layout),
                coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))
            );
        }
    } else {
        // SM80 架构的转换逻辑
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK {});
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2> {});
            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
        }
    }
}



// 转换数据类型的模板函数
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type; // 获取源数据类型
    constexpr int numel = decltype(size(tensor))::value; // 获取元素数量

    // 使用 CUTLASS 的数值转换器进行类型转换
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}



// 异步内存操作等待函数,用于优化异步操作的等待时间
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); // 使用内联汇编调用 cp.async.wait_group
#endif
}



// 张量拷贝函数,用于优化矩阵拷贝操作
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                            Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {

    // 静态断言,确保张量的维度一致
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); // 确保参数不会冲突

    // 循环遍历张量的维度,进行拷贝操作
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            // 满足条件时执行拷贝
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); // 执行拷贝
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k)); // 清除超出边界的元素
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _)); // 清除超出边界的行或列
        }
    }
}



} // namespace flash
init.py
##导入用于机器学习加速(Flash Machine Learning Acceleration,Flash MLA)的接口函数
__version__ = "1.0.0"  # 定义项目的版本号
# 从 flash_mla.flash_mla_interface 模块中导入两个函数
from flash_mla.flash_mla_interface import (
    get_mla_metadata,  # 获取机器学习加速(MLA)的元数据,可能包含模型设置、超参数等信息
    flash_mla_with_kvcache  # 使用键值缓存(KV Cache)的机器学习加速函数,用于优化模型推理或训练过程中的数据缓存
)
flash_mla_interface.py
#  通过高效的 GPU 编程实现大规模注意力机制的计算,大幅度提升深度学习模型的推理和训练效率
from typing import Optional, Tuple  # 导入类型提示相关的模块

import torch  # 导入 PyTorch,用于张量操作和深度学习功能

import flash_mla_cuda  # 导入自定义的 CUDA 扩展模块,用于实现 Flash MLA(机器学习加速)功能


def get_mla_metadata(  # 获取 Flash MLA 的元数据
    cache_seqlens: torch.Tensor,  # 缓存的序列长度,shape 为 (batch_size),dtype torch.int32
    num_heads_per_head_k: int,  # 每个 K 头对应的查询头数量,等于 seq_len_q * num_heads_q // num_heads_k
    num_heads_k: int,  # K 头的数量
) -> Tuple[torch.Tensor, torch.Tensor]:  # 返回两个张量
    """
    Function to retrieve metadata for Flash Machine Learning Acceleration (MLA).

    Args:
        cache_seqlens (torch.Tensor): A 1D tensor of shape (batch_size) containing cache sequence lengths.
        num_heads_per_head_k (int): Number of query heads per K head (equals seq_len_q * num_heads_q // num_heads_k).
        num_heads_k (int): Number of K heads.

    Returns:
        tile_scheduler_metadata (torch.Tensor): Metadata for tile scheduling, shape (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits (torch.Tensor): Array of splits for partitioning, shape (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)  # 调用 CUDA 函数获取元数据


def flash_mla_with_kvcache(  # 使用键值缓存(KV Cache)的 Flash MLA 函数
    q: torch.Tensor,  # 查询张量,shape (batch_size, seq_len_q, num_heads_q, head_dim)
    k_cache: torch.Tensor,  # 键缓存,shape (num_blocks, page_block_size, num_heads_k, head_dim)
    block_table: torch.Tensor,  # 块表格,shape (batch_size, max_num_blocks_per_seq), dtype torch.int32
    cache_seqlens: torch.Tensor,  # 缓存的序列长度,shape (batch_size), dtype torch.int32
    head_dim_v: int,  # V 头的维度
    tile_scheduler_metadata: torch.Tensor,  # 瓦片调度元数据,shape (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32
    num_splits: torch.Tensor,  # 分割数组,shape (batch_size + 1), dtype torch.int32
    softmax_scale: Optional[float] = None,  # softmax 的缩放因子,默认为 1 / sqrt(head_dim)
    causal: bool = False,  # 是否应用因果掩码
) -> Tuple[torch.Tensor, torch.Tensor]:  # 返回两个张量
    """
    Flash Machine Learning Acceleration with KV Cache for efficient attention computation.

    Args:
        q (torch.Tensor): Query tensor of shape (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache (torch.Tensor): Key cache tensor of shape (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table (torch.Tensor): Block table tensor of shape (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_seqlens (torch.Tensor): Cache sequence lengths tensor of shape (batch_size), dtype torch.int32.
        head_dim_v (int): Head dimension for value.
        tile_scheduler_metadata (torch.Tensor): Tile scheduler metadata tensor, shape (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits (torch.Tensor): Array of splits for partitioning, shape (batch_size + 1), dtype torch.int32.
        softmax_scale (float, optional): Scaling factor for softmax. Defaults to 1 / sqrt(head_dim).
        causal (bool, optional): Whether to apply causal attention mask. Defaults to False.

    Returns:
        out (torch.Tensor): Output tensor of shape (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse (torch.Tensor): LogSumExp values for softmax, shape (batch_size, num_heads_q, seq_len_q), dtype torch.float32.
    """
    if softmax_scale is None:  # 如果 softmax_scale 未指定,计算默认值
        softmax_scale = q.shape[-1] ** (-0.5)  # 默认缩放因子为 1 / sqrt(head_dim)

    # 调用 CUDA 函数执行前向计算
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(  
        q,  # 查询张量
        k_cache,  # 键缓存
        None,  # 未使用
        head_dim_v,  # V 头的维度
        cache_seqlens,  # 缓存的序列长度
        block_table,  # 块表格
        softmax_scale,  # softmax 缩放因子
        causal,  # 是否应用因果掩码
        tile_scheduler_metadata,  # 瓦片调度元数据
        num_splits,  # 分割数组
    )

    return out, softmax_lse  # 返回输出张量和 softmax 的 LogSumExp 向量
setup.py
#构建和分发一个包含 CUDA 扩展模块的 Python 包,该扩展模块利用了 CUTLASS 库和 Flash Machine Learning Acceleration(Flash MLA)技术来加速深度学习模型中的矩阵运算和注意力机制计算。
import os  # 导入操作系统相关模块,用于获取环境变量和操作目录
from pathlib import Path  # 用于处理文件路径
from datetime import datetime  # 用于获取当前时间
import subprocess  # 用于调用外部命令
from setuptools import setup, find_packages  # 导入 setuptools,用于打包和分发 Python 包
from torch.utils.cpp_extension import (
    BuildExtension,  # 用于构建 C++/CUDA 扩展的构建扩展类
    CUDAExtension,  # 用于定义 CUDA 扩展
    IS_WINDOWS,  # 是否是 Windows 系统
)

def append_nvcc_threads(nvcc_extra_args):  # 添加多线程支持
    """
    Append NVIDIA CUDA Compiler (nvcc) threads setting to the compilation arguments.

    Args:
        nvcc_extra_args (list): Existing nvcc compilation arguments.

    Returns:
        list: Updated nvcc compilation arguments with thread settings.
    """
    nvcc_threads = os.getenv("NVCC_THREADS") or "32"  # 获取环境变量中的线程数,默认为 32
    return nvcc_extra_args + ["--threads", nvcc_threads]  # 添加线程设置到编译参数

# 更新 Git 子模块 CUTLASS
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])  # 初始化和更新 Git 子模块 CUTLASS,用于获得最新的 CUTLASS 库代码

# 定义 CUDA 编译架构
cc_flag = []  # 初始化 CUDA 架构标志列表
cc_flag.append("-gencode")  # 添加代码生成标志
cc_flag.append("arch=compute_90a,code=sm_90a")  # 指定目标架构为计算架构 90a 和架构 90a

this_dir = os.path.dirname(os.path.abspath(__file__))  # 获取当前脚本所在的目录

# 定义 C++ 编译参数
if IS_WINDOWS:
    cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]  # Windows 系统下的 C++ 编译参数
else:
    cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]  # 非 Windows 系统下的 C++ 编译参数

# 定义 CUDA 扩展模块
ext_modules = []
ext_modules.append(
    CUDAExtension(  # 定义 CUDA 扩展模块
        name="flash_mla_cuda",  # 扩展模块的名称
        sources=[  # C++ 和 CUDA 源文件路径
            "csrc/flash_api.cpp",
            "csrc/flash_fwd_mla_bf16_sm90.cu",
        ],
        extra_compile_args={  # 额外的编译参数
            "cxx": cxx_args,  # C++ 编译参数
            "nvcc": append_nvcc_threads(  # CUDA 编译参数
                [
                    "-O3",
                    "-std=c++17",
                    "-DNDEBUG",
                    "-D_USE_MATH_DEFINES",
                    "-Wno-deprecated-declarations",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    "--ptxas-options=-v,--register-usage-level=10"
                ]
                + cc_flag  # 添加 CUDA 架构参数
            ),
        },
        include_dirs=[  # 包含目录,指定额外的头文件路径
            Path(this_dir) / "csrc",
            Path(this_dir) / "csrc" / "cutlass" / "include",
        ],
    )
)

# 获取 Git 提交哈希或生成时间版本号
try:
    cmd = ['git', 'rev-parse', '--short', 'HEAD']  # 获取 Git 提交哈希的命令
    rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()  # 执行命令并获取提交哈希
except Exception as _:
    now = datetime.now()  # 获取当前时间
    date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")  # 格式化时间字符串
    rev = '+' + date_time_str  # 生成基于时间的版本号

# 配置打包信息
setup(
    name="flash_mla",  # 包的名称
    version="1.0.0" + rev,  # 包的版本号
    packages=find_packages(include=['flash_mla']),  # 包含的 Python 模块
    ext_modules=ext_modules,  # 包含的扩展模块
    cmdclass={"build_ext": BuildExtension},  # 构建类,用于构建 C++/CUDA 扩展
)

3. FlashMLA 的技术亮点

3.1 内存优化

FlashMLA 通过以下技术优化内存使用:

  • 分页 KV 缓存:采用块大小为 64 的分页式显存管理,避免了传统连续内存分配导致的显存碎片化问题。这一设计使得单卡能够并行处理超过 200 个对话线程,服务密度提升 3 倍。

  • BF16 精度支持:FlashMLA 支持 BF16 数据格式,显存占用减少 50%。结合低秩压缩技术,KV 缓存体积可压缩至原体积的 1/4,实现 93.3% 的 KV 缓存量削减。

3.2 性能提升

FlashMLA 在 NVIDIA H800 SXM5 GPU 上表现出卓越的性能:

  • 内存带宽:在内存受限的配置下,FlashMLA 能够实现 3000 GB/s 的带宽。

  • 计算性能:在计算受限的配置下,FlashMLA 能够达到 580 TFLOPS 的计算性能。

  • 性能对比:与传统模型推理框架(如 PyTorch)相比,FlashMLA 的性能提升了 10 倍;与 FlashAttention 相比,性能提升了 3-5 倍

3.3 其他优化

FlashMLA 的设计灵感来源于 FlashAttention 2&3 和英伟达的 CUTLASS 项目。它通过以下方式进一步提升推理效率:

  • 线性复杂度设计:FlashMLA 采用线性复杂度的设计,避免了传统 MHA 的二次复杂度瓶颈。

  • 优化 KV 缓存机制:通过高效的 KV 缓存管理,FlashMLA 实现了更高的推理吞吐量和更低的延迟。

4. 应用场景和优势(涵盖硬件)

FlashMLA 适用场景:
  1. 长序列处理:适合处理数千个标记的文本,如文档分析、长对话等。

  2. 实时应用:如聊天机器人、虚拟助手、实时翻译系统等,显著降低延迟,提升用户体验。

  3. 资源效率优化:减少内存和计算需求,便于在边缘设备或资源受限的环境中部署。

硬件适配与成本优势

FlashMLA 目前仅适配 DeepSeek 的架构模型(如 DeepSeek-R1 和 DeepSeek-V3),并专为 NVIDIA H 系列显卡(如 H800 SXM5)进行了优化。通过优化内存和计算效率,FlashMLA 显著降低了服务器的硬件成本。例如:

  • 推理成本降低:单位推理成本大幅降低,使得 AI 公司和云计算服务商能够在相同的 GPU 资源下处理更多请求。

  • 硬件资源优化:通过减少显存占用和提升计算效率,FlashMLA 使得企业能够在有限的硬件资源下实现更高的推理吞吐量。


http://www.kler.cn/a/570691.html

相关文章:

  • DDD该怎么去落地实现(4)多对多关系
  • Docker安装Postgres_16数据库
  • cordova app webpack升级为vite
  • B3DM转换成OBJ
  • 图论-腐烂的橘子
  • 汽车轮胎损伤缺陷分割数据集labelme格式1957张3类别
  • 【Elasticsearch】时间序列数据流(Time Series Data Stream,TSDS)
  • 【原创】Ubuntu 24自动分区后的根目录扩展
  • 配置Spring Boot API接口超时时间(五种)
  • C++学习之C++初识、C++对C语言增强、对C语言扩展
  • 基于vue3 + ts 封装一个自定义的message组件
  • 基于机器学习的智能谣言检测系统
  • java jar包内的jar包如何打补丁
  • 2022 年 12 月青少年软编等考 C 语言五级真题解析
  • 第6篇:面向对象编程重构系统
  • Spring MVC 返回数据
  • 如何实现前端“小手向右指”的效果
  • Zookeeper 的核心引擎:深入解析 ZAB 协议
  • ECharts饼图高级美化技巧:泛光效果实现与间隔布局
  • JavaWeb3、Tomcat