一、源码分析
1.1 函数入口
void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q,
torch::Tensor K,
torch::Tensor V,
torch::Tensor O,
int stages) {
CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
const int d = Q.size(3); // B, H, N, d
if (stages > 1) {
switch (d)
{
case 32:
launch_flash_attn_mma_stages_split_q_shared_kv<32, 2>(Q, K, V, O);