当前位置: 首页 > article >正文

如何在ms-swift 微调训练deepseekvl2时使用sageattention

        sageattention 据说比flash_atten_2还要快很多。 但是如何在deepseekvl2这训练这里把它用上呢?

        1.本质上sageattention是sdpa,SDPA的全称为Scaled Dot-Product Attention, 属于乘性注意力机制, 简单一句话来说就是,根据Query (Q)与Key之间的匹配度来对Value进行加权,而事实上不管是Query, Key还是Value都来自于输入,因此所谓的SDPA本质上是对输入信息信息进行重组。

       2. sageattention使用了Triton这个包,它可以把python的代码编译成目标机器码,大大加速这个运算的速度。

        3. 从官方例子可以看到,它本质的工作原理就是简单的替换torch.nn.functional.scaled_dot_product_attention这个函数,示例如下:

from sageattention import sageattn
import torch.nn.functional as F
。。。
    F.scaled_dot_product_attention = sageattn

        所以,要用上sageatten,其实只需要原来的模型支持sdpa的注意力机制即可。

        但很不幸,从deepseekvl2官方开源github可以看到DeepseekVLV2ForCausalLM和DeepseekV2ForCausalLM都是不支持sdpa的,这个两个类都没有声明_supports_sdpa = True

ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,

    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,

    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2
}

        这个attention_class也没有sdpa的实现。

        因此,deepseekvl2无法直接简单使用sageattion,我们需要改一下deepseek的开源代码,才有可能用上sageattion.

        修改步骤如下:

1. 首先要让DeepseekVLV2ForCausalLM和DeepseekV2ForCausalLM先支持sdpa,添加_supports_sdpa,这样transformers/modeling_utils.py的_check_and_enable_sdpa才可以检查通过.

class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    _supports_sdpa = True
class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
    _supports_sdpa = True

2.然后在DeepSeek-VL2/deepseek_vl2/models/modeling_deepseek.py添加sdpa的attention的实现类,我们让它继承LlamaAttention,如下,其实这个直接抄的LlamaSdpaAttention,copy是为了修改方便,实现如下:


class DeepSeekSdpaAttention(LlamaAttention):
    """
    Deepseek attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `DeepseekV2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            if isinstance(position_embeddings, torch.Tensor):
                cos, sin = self.rotary_emb(value_states, position_ids)
            else:
                cos, sin = position_embeddings


        query_states, key_states = apply_rotary_pos_emb2(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True  #if causal_mask is None and q_len > 1 else False

        attn_output = sageattn(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

3. 修改modeling_deepseek.py的ATTENTION_CLASSES,加上sdpa支持,如下:


ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,

    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,

    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2,
    "mha_sdpa": DeepSeekSdpaAttention
}

4.使用的sageattion和Triton的版本如下:

Name: sageattention
Version: 1.0.6

Name: triton
Version: 3.2.0

5. 训练测试,

swift sft  --model "deepseek-ai/deepseek-vl2-tiny"  --dataset  ../TEST.json  --attn_impl sdp

。。。
using attn_implementation: mha_sdpa
[INFO:swift] model.hf_device_map: {'': device(type='cuda', index=0)}

。。。
[INFO:swift] End time of running main: 2025-03-28 10:43:29.304164

成功进行了训练。


http://www.kler.cn/a/614363.html

相关文章:

  • Solidity 番外篇 | 最新 EIP 动态、Solidity 版本变化、Web3 职业路径全览
  • 火语言RPA--生成随机/格式化字符串或数字
  • ros2--功能包
  • QT图片轮播器实现方法二(QT实操2)
  • Linux-Centos离线环境安装python3
  • DHCP报文的详细流程
  • linux驱动相关资料,网址链接
  • Effective C++ 剖析(条款10~22)
  • 关于VUE中v-model响应式失效的问题
  • 【杂谈】-人工智能驱动的编码:提升效率还是增加网络安全隐患?
  • Oracle数据库数据编程SQL<3.1 PL/SQL 匿名块 及 流程控制中的条件判断、循环、异常处理和随机函数应用>
  • 知能行每日综测
  • C# .net ai Agent AI视觉应用 写代码 改作业 识别屏幕 标注等
  • 全球化2.0 | ZStack举办香港Partner Day,推动AIOS智塔+DeepSeek海外实践
  • 【云原生】docker 搭建单机PostgreSQL操作详解
  • 【Prometheus】Prometheus的特点、数据采集方式、架构、数据模型详解
  • 【Linux指南】Linux内核:操作系统的核心引擎
  • 前端快速系统学习Rust的路径
  • 【CSS】相对位置小练习
  • 基于springboot+vue的农产品电商平台